Example #1
0
def train():
    """Train one multi-dataset batch model, evaluating periodically.

    Draws mixed batches from all training datasets, evaluates on episodic
    validation tasks every ``train.eval_freq`` iterations, and checkpoints
    the model whose average validation accuracy is best so far.

    Relies on the module-level ``args`` dict and the data / model /
    checkpointing helpers imported at file level.
    """
    # initialize datasets and loaders
    trainsets, valsets, testsets = args['data.train'], args['data.val'], args[
        'data.test']
    train_loader = MetaDatasetBatchReader('train',
                                          trainsets,
                                          valsets,
                                          testsets,
                                          batch_size=args['train.batch_size'])
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets)

    # initialize model and optimizer; the classifier head covers the union
    # of classes over all training datasets
    num_train_classes = sum(train_loader.dataset_to_n_cats.values())
    model = get_model(num_train_classes, args)
    optimizer = get_optimizer(model, args, params=model.get_parameters())

    # restore the last checkpoint only when resuming was requested
    checkpointer = CheckPointer(args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']:
        start_iter, best_val_loss, best_val_acc =\
            checkpointer.restore_model(ckpt='last', strict=False)
    else:
        print('No checkpoint restoration')
        best_val_loss = 999999999
        best_val_acc = start_iter = 0

    # define learning rate policy
    if args['train.lr_policy'] == "step":
        lr_manager = UniformStepLR(optimizer, args, start_iter)
    elif "exp_decay" in args['train.lr_policy']:
        lr_manager = ExpDecayLR(optimizer, args, start_iter)
    elif "cosine" in args['train.lr_policy']:
        lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter)
    else:
        # fail fast instead of a NameError at the first lr_manager.step()
        raise ValueError(
            f"Unknown train.lr_policy: {args['train.lr_policy']!r}")

    # defining the summary writer
    writer = SummaryWriter(checkpointer.model_path)

    # Training loop
    max_iter = args['train.max_iter']
    # per-dataset running metric buffers, flushed every 200 iterations
    epoch_loss = {name: [] for name in trainsets}
    epoch_acc = {name: [] for name in trainsets}
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as session:
        for i in tqdm(range(max_iter)):
            # fast-forward the iteration counter when resuming
            if i < start_iter:
                continue

            optimizer.zero_grad()
            sample = train_loader.get_train_batch(session)
            batch_dataset = sample['dataset_name']
            logits = model.forward(sample['images'])
            labels = sample['labels']
            batch_loss, stats_dict, _ = cross_entropy_loss(logits, labels)
            epoch_loss[batch_dataset].append(stats_dict['loss'])
            epoch_acc[batch_dataset].append(stats_dict['acc'])

            batch_loss.backward()
            optimizer.step()
            lr_manager.step(i)

            # flush running train summaries every 200 iterations
            if (i + 1) % 200 == 0:
                for dataset_name in trainsets:
                    # BUGFIX: the loss curve was previously tagged
                    # "loss/...-train_acc"; tag it as a loss
                    writer.add_scalar(f"loss/{dataset_name}-train_loss",
                                      np.mean(epoch_loss[dataset_name]), i)
                    writer.add_scalar(f"accuracy/{dataset_name}-train_acc",
                                      np.mean(epoch_acc[dataset_name]), i)
                    epoch_loss[dataset_name], epoch_acc[dataset_name] = [], []

                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], i)

            # Evaluation inside the training loop
            if (i + 1) % args['train.eval_freq'] == 0:
                model.eval()
                dataset_accs, dataset_losses = [], []
                for valset in valsets:
                    val_losses, val_accs = [], []
                    for j in tqdm(range(args['train.eval_size'])):
                        with torch.no_grad():
                            sample = val_loader.get_validation_task(
                                session, valset)
                            context_features = model.embed(
                                sample['context_images'])
                            target_features = model.embed(
                                sample['target_images'])
                            context_labels = sample['context_labels']
                            target_labels = sample['target_labels']
                            # nearest-prototype evaluation on the episode
                            _, stats_dict, _ = prototype_loss(
                                context_features, context_labels,
                                target_features, target_labels)
                        val_losses.append(stats_dict['loss'])
                        val_accs.append(stats_dict['acc'])

                    # write summaries per validation set
                    dataset_acc, dataset_loss = np.mean(
                        val_accs) * 100, np.mean(val_losses)
                    dataset_accs.append(dataset_acc)
                    dataset_losses.append(dataset_loss)
                    writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss,
                                      i)
                    writer.add_scalar(f"accuracy/{valset}/val_acc",
                                      dataset_acc, i)
                    print(
                        f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}"
                    )

                # write summaries averaged over datasets
                avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(
                    dataset_accs)
                writer.add_scalar("loss/avg_val_loss", avg_val_loss, i)
                writer.add_scalar("accuracy/avg_val_acc", avg_val_acc, i)

                # saving checkpoints; best-so-far is selected by accuracy
                if avg_val_acc > best_val_acc:
                    best_val_loss, best_val_acc = avg_val_loss, avg_val_acc
                    is_best = True
                    print('Best model so far!')
                else:
                    is_best = False
                checkpointer.save_checkpoint(i,
                                             best_val_acc,
                                             best_val_loss,
                                             is_best,
                                             optimizer=optimizer,
                                             state_dict=model.get_state_dict())

                model.train()
                print(f"Trained and evaluated at {i}")

    writer.close()
    if start_iter < max_iter:
        print(
            f"""Done training with best_mean_val_loss: {best_val_loss:.3f}, best_avg_val_acc: {best_val_acc:.2f}%"""
        )
    else:
        print(
            f"""No training happened. Loaded checkpoint at {start_iter}, while max_iter was {max_iter}"""
        )
Example #2
0
def main(args):
    """Evaluate a trained classifier on the test split of the configured
    dataset (MNIST / SOP / Shopee).

    Loads the checkpoint from ``args.model_path``, logs per-batch loss and
    accuracy, and optionally visualizes test batches and misclassified
    samples. Inference runs entirely under ``torch.no_grad()``.
    """
    args.sample_size = 28
    args.validation_split = 0.0  # evaluate on the full test split

    # input resolution for testing, per dataset / backbone
    if args.dataset == 'MNIST':
        args.sample_size = 28
    # BUGFIX: `args.dataset == 'SOP' or 'Shopee'` was always truthy
    # (a non-empty string literal is truthy); test membership instead.
    elif args.dataset in ('SOP', 'Shopee'):
        if args.model == 'resnet':
            args.sample_size = 224
        elif args.model == 'vgg' or args.model == 'vgg_attn':
            args.sample_size = 224
        elif args.model == 'inception':
            args.sample_size = 299
        else:
            args.sample_size = 224

    spatial_transform_test = get_test_transform(args)
    crop_transform = get_crop_transform(args)

    if args.dataset == 'MNIST':
        test_data_loader = data_loading.CroppedMNISTLoader(
            args,
            crop_transform=crop_transform,
            spatial_transform=spatial_transform_test,
            training=False)
    elif args.dataset == 'SOP':
        test_data_loader = data_loading.CroppedSOPLoader(
            args,
            crop_transform=crop_transform,
            spatial_transform=spatial_transform_test,
            training=False)
    elif args.dataset == 'Shopee':
        test_data_loader = data_loading.ShopeeDataLoader(
            args,
            crop_transform=crop_transform,
            spatial_transform=spatial_transform_test,
            training=False)
    else:
        # fail fast instead of a NameError further down
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    args.n_classes = test_data_loader.n_classes
    # ratio of normal (label 0) to cropped (label 1) samples in the test data
    data_ratio = sum(1 for sample in test_data_loader.dataset.samples if sample[1] == 0) / \
                 sum(1 for sample in test_data_loader.dataset.samples if sample[1] == 1)
    print("normal data/cropped data ratio: {}".format(data_ratio))

    # prepare the model for testing
    model, _ = get_model(args)
    model = model.to(device)
    model.eval()

    test_logger = Logger(
        os.path.join(args.log_path, 'test_{}.log'.format(args.dataset)),
        ['batch', 'loss', 'acc'])

    # record the evaluation configuration for reproducibility
    revision_logger = Logger(
        os.path.join(args.log_path, 'test_config_{}.log'.format(args.dataset)),
        [
            'dataset', 'dataset_size', 'train_test_split', 'n_classes',
            'model', 'model_depth', 'test_batch_size', 'crop_scale',
            'cropped_data_ratio', 'shuffle'
        ])
    revision_logger.log({
        'dataset': args.dataset,
        'dataset_size': args.dataset_size,
        'train_test_split': args.train_test_split,
        'n_classes': args.n_classes,
        'model': args.model,
        'model_depth': args.model_depth,
        'test_batch_size': args.batch_size,
        'crop_scale': args.crop_scale,
        'cropped_data_ratio': args.cropped_data_ratio,
        'shuffle': args.shuffle
    })

    # load the trained model weights
    print('loading checkpoint {}'.format(args.model_path))
    checkpoint = torch.load(args.model_path, map_location=device)
    assert args.arch == checkpoint['arch']
    model.load_state_dict(checkpoint['state_dict'])

    criterion = cross_entropy_loss()

    accuracies = AverageMeter()
    losses = AverageMeter()

    visualizer = TestVisualizer(test_data_loader.classes,
                                test_data_loader.rgb_mean,
                                test_data_loader.rgb_std, 5, args)

    show_misclassified = False  # flip on to collect and page through errors
    block = False  # non-blocking visualizer windows during the loop

    with torch.no_grad():

        misclassified_images_full = []
        misclassified_true_labels_full = []
        misclassified_predictions_full = []

        for i, (inputs, targets) in enumerate(test_data_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            losses.update(loss.item(), inputs.size(0))
            acc = calculate_accuracy(outputs, targets)
            accuracies.update(acc, inputs.size(0))

            # top-1 predictions; NOTE(review): squeeze() yields a 0-dim
            # tensor when a batch has a single sample — confirm the loader
            # never produces batch size 1
            _, predictions = outputs.topk(1, 1, True)
            fails = predictions.squeeze() != targets
            predictions = predictions.squeeze().cpu().numpy()
            # int(...) so the log shows a plain count, not a tensor repr
            print("Number of misclassifications: {}/{}".format(
                int(fails.sum()), len(inputs)))

            if show_misclassified:
                misclassified_idxs = [
                    k for k, failed in enumerate(fails) if failed
                ]
                misclassified_images = [
                    inputs[idx] for idx in misclassified_idxs
                ]
                misclassified_true_labels = [
                    targets[idx] for idx in misclassified_idxs
                ]
                misclassified_predictions = [
                    predictions[idx] for idx in misclassified_idxs
                ]

                misclassified_images_full.extend(misclassified_images)
                misclassified_true_labels_full.extend(
                    misclassified_true_labels)
                misclassified_predictions_full.extend(
                    misclassified_predictions)

            if args.show_test_images and i % args.plot_interval == 0:
                visualizer.make_grid(inputs, targets, predictions)
                visualizer.show(block)

            test_logger.log({
                'batch': i + 1,
                'loss': losses.avg,
                'acc': accuracies.avg
            })

            print('Batch: [{0}/{1}]\t'
                  'Loss {loss.value:.4f} (avg {loss.avg:.4f})\t'
                  'Acc {acc.value:.3f} (avg {acc.avg:.3f})'.format(
                      i + 1,
                      len(test_data_loader),
                      loss=losses,
                      acc=accuracies))

        # page through collected misclassifications, n_images**2 per grid
        if show_misclassified:
            keep_going = True
            n_images = 5
            start = 0
            while keep_going:
                error_visualizer = TestVisualizer(test_data_loader.classes,
                                                  test_data_loader.rgb_mean,
                                                  test_data_loader.rgb_std,
                                                  n_images, args)
                if n_images**2 < len(misclassified_images_full[start:]):
                    images, labels, predictions = misclassified_images_full[start:start+n_images**2], \
                                                  misclassified_true_labels_full[start:start+n_images**2], \
                                                  misclassified_predictions_full[start:start+n_images**2]
                else:
                    # last (possibly partial) page
                    images, labels, predictions = misclassified_images_full[start:], \
                                                  misclassified_true_labels_full[start:], \
                                                  misclassified_predictions_full[start:]
                    keep_going = False
                error_visualizer.make_grid(images, labels, predictions)
                error_visualizer.show(True)
                start += n_images**2

    print('Test log written to {}'.format(test_logger.log_file))
Example #3
0
def train():
    """Multi-domain training with knowledge distillation.

    Trains one multi-domain model on per-dataset batch readers while
    distilling from pre-trained single-domain extractors, combining a
    feature-level distillation term (kd_f) and a prediction-level
    KL-divergence term (kd_p), each with per-dataset annealed weights.
    Periodically evaluates via prototype-based episodes and checkpoints
    the best model by average validation accuracy.

    Relies on the module-level `args` dict plus file-level helpers and
    constants (BATCHSIZES, KDANNEALING, LOSSWEIGHTS, KDFLOSSWEIGHTS,
    KDPLOSSWEIGHTS, DATASET_MODELS_DICT, device, ...).
    """
    # initialize datasets and loaders
    trainsets, valsets, testsets = args['data.train'], args['data.val'], args[
        'data.test']

    # one batch reader per training dataset, each with its own batch size
    train_loaders = []
    num_train_classes = dict()
    kd_weight_annealing = dict()
    for t_indx, trainset in enumerate(trainsets):
        train_loaders.append(
            MetaDatasetBatchReader('train', [trainset],
                                   valsets,
                                   testsets,
                                   batch_size=BATCHSIZES[trainset]))
        num_train_classes[trainset] = train_loaders[t_indx].num_classes(
            'train')
        # setting up knowledge distillation losses weights annealing
        kd_weight_annealing[trainset] = WeightAnnealing(
            T=int(args['train.cosine_anneal_freq'] * KDANNEALING[trainset]))
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets)

    # initialize model and optimizer
    model = get_model(list(num_train_classes.values()), args)
    model_name_temp = args['model.name']
    # KL-divergence loss
    criterion_div = DistillKL(T=4)
    # get a MTL model initialized by ImageNet pretrained model and deactivate the pretrained flag
    args['model.pretrained'] = False
    optimizer = get_optimizer(model, args, params=model.get_parameters())
    # adaptors for aligning features between MDL and SDL models
    adaptors = adaptor(num_datasets=len(trainsets),
                       dim_in=512,
                       opt=args['adaptor.opt']).to(device)
    optimizer_adaptor = torch.optim.Adam(adaptors.parameters(),
                                         lr=0.1,
                                         weight_decay=5e-4)

    # loading single domain learning networks
    extractor_domains = trainsets
    dataset_models = DATASET_MODELS_DICT[args['model.backbone']]
    embed_many = get_domain_extractors(extractor_domains, dataset_models, args,
                                       num_train_classes)

    # restoring the last checkpoint
    args['model.name'] = model_name_temp
    checkpointer = CheckPointer(args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.out_last_ckpt) and args['train.resume']:
        start_iter, best_val_loss, best_val_acc =\
            checkpointer.restore_out_model(ckpt='last')
    else:
        print('No checkpoint restoration')
        best_val_loss = 999999999
        best_val_acc = start_iter = 0

    # define learning rate policy (separate managers for the model and
    # the adaptor optimizers)
    # NOTE(review): no else branch — an unrecognized lr_policy leaves
    # lr_manager/lr_manager_ad undefined and fails later with NameError
    if args['train.lr_policy'] == "step":
        lr_manager = UniformStepLR(optimizer, args, start_iter)
        lr_manager_ad = UniformStepLR(optimizer_adaptor, args, start_iter)
    elif "exp_decay" in args['train.lr_policy']:
        lr_manager = ExpDecayLR(optimizer, args, start_iter)
        lr_manager_ad = ExpDecayLR(optimizer_adaptor, args, start_iter)
    elif "cosine" in args['train.lr_policy']:
        lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter)
        lr_manager_ad = CosineAnnealRestartLR(optimizer_adaptor, args,
                                              start_iter)

    # defining the summary writer
    writer = SummaryWriter(checkpointer.out_path)

    # Training loop
    max_iter = args['train.max_iter']
    # per-dataset running metric buffers, flushed every 200 iterations
    epoch_loss = {name: [] for name in trainsets}
    epoch_kd_f_loss = {name: [] for name in trainsets}
    epoch_kd_p_loss = {name: [] for name in trainsets}
    epoch_acc = {name: [] for name in trainsets}
    epoch_val_loss = {name: [] for name in valsets}
    epoch_val_acc = {name: [] for name in valsets}
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = False
    with tf.compat.v1.Session(config=config) as session:
        for i in tqdm(range(max_iter)):
            # fast-forward the iteration counter when resuming
            if i < start_iter:
                continue

            optimizer.zero_grad()
            optimizer_adaptor.zero_grad()

            samples = []
            images = dict()
            num_samples = []
            # loading images and labels
            for t_indx, (name, train_loader) in enumerate(
                    zip(trainsets, train_loaders)):
                sample = train_loader.get_train_batch(session)
                samples.append(sample)
                images[name] = sample['images']
                num_samples.append(sample['images'].size(0))

            # one concatenated forward pass over all datasets' batches;
            # num_samples lets the model split per-dataset outputs again
            logits, mtl_features = model.forward(torch.cat(list(
                images.values()),
                                                           dim=0),
                                                 num_samples,
                                                 kd=True)
            stl_features, stl_logits = embed_many(images,
                                                  return_type='list',
                                                  kd=True,
                                                  logits=True)
            mtl_features = adaptors(mtl_features)

            batch_losses, stats_dicts = [], []
            kd_f_losses = 0
            kd_p_losses = 0
            for t_indx, trainset in enumerate(trainsets):
                # per-dataset cross-entropy, weighted by LOSSWEIGHTS
                batch_loss, stats_dict, _ = cross_entropy_loss(
                    logits[t_indx], samples[t_indx]['labels'])
                batch_losses.append(batch_loss * LOSSWEIGHTS[trainset])
                stats_dicts.append(stats_dict)
                batch_dataset = samples[t_indx]['dataset_name']
                epoch_loss[batch_dataset].append(stats_dict['loss'])
                epoch_acc[batch_dataset].append(stats_dict['acc'])
                # L2-normalize single-domain (ft, detached below, i.e. the
                # teacher side) and multi-domain (fs) features before the
                # feature distillation loss
                ft, fs = torch.nn.functional.normalize(
                    stl_features[t_indx], p=2, dim=1,
                    eps=1e-12), torch.nn.functional.normalize(
                        mtl_features[t_indx], p=2, dim=1, eps=1e-12)
                kd_f_losses_ = distillation_loss(fs,
                                                 ft.detach(),
                                                 opt='kernelcka')
                kd_p_losses_ = criterion_div(logits[t_indx],
                                             stl_logits[t_indx])
                # annealed per-dataset weights for both distillation terms
                kd_weight = kd_weight_annealing[trainset](
                    t=i, opt='linear') * KDFLOSSWEIGHTS[trainset]
                bam_weight = kd_weight_annealing[trainset](
                    t=i, opt='linear') * KDPLOSSWEIGHTS[trainset]
                if kd_weight > 0:
                    kd_f_losses = kd_f_losses + kd_f_losses_ * kd_weight
                if bam_weight > 0:
                    kd_p_losses = kd_p_losses + kd_p_losses_ * bam_weight
                epoch_kd_f_loss[batch_dataset].append(kd_f_losses_.item())
                epoch_kd_p_loss[batch_dataset].append(kd_p_losses_.item())

            # total = weighted CE + sigma * feature-KD + beta * prediction-KD
            batch_loss = torch.stack(batch_losses).sum()
            kd_f_loss = kd_f_losses * args['train.sigma']
            kd_p_loss = kd_p_losses * args['train.beta']
            batch_loss = batch_loss + kd_f_loss + kd_p_loss

            batch_loss.backward()
            optimizer.step()
            optimizer_adaptor.step()
            lr_manager.step(i)
            lr_manager_ad.step(i)

            # flush running train summaries every 200 iterations
            if (i + 1) % 200 == 0:
                for dataset_name in trainsets:
                    writer.add_scalar(f"loss/{dataset_name}-train_loss",
                                      np.mean(epoch_loss[dataset_name]), i)
                    writer.add_scalar(f"accuracy/{dataset_name}-train_acc",
                                      np.mean(epoch_acc[dataset_name]), i)
                    writer.add_scalar(
                        f"kd_f_loss/{dataset_name}-train_kd_f_loss",
                        np.mean(epoch_kd_f_loss[dataset_name]), i)
                    writer.add_scalar(
                        f"kd_p_loss/{dataset_name}-train_kd_p_loss",
                        np.mean(epoch_kd_p_loss[dataset_name]), i)
                    epoch_loss[dataset_name], epoch_acc[
                        dataset_name], epoch_kd_f_loss[
                            dataset_name], epoch_kd_p_loss[
                                dataset_name] = [], [], [], []

                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], i)

            # Evaluation inside the training loop
            if (i + 1) % args['train.eval_freq'] == 0:
                model.eval()
                dataset_accs, dataset_losses = [], []
                for valset in valsets:
                    val_losses, val_accs = [], []
                    for j in tqdm(range(args['train.eval_size'])):
                        with torch.no_grad():
                            sample = val_loader.get_validation_task(
                                session, valset)
                            context_features = model.embed(
                                sample['context_images'])
                            target_features = model.embed(
                                sample['target_images'])
                            context_labels = sample['context_labels']
                            target_labels = sample['target_labels']
                            # nearest-prototype classification on the episode
                            _, stats_dict, _ = prototype_loss(
                                context_features, context_labels,
                                target_features, target_labels)
                        val_losses.append(stats_dict['loss'])
                        val_accs.append(stats_dict['acc'])

                    # write summaries per validation set
                    dataset_acc, dataset_loss = np.mean(
                        val_accs) * 100, np.mean(val_losses)
                    dataset_accs.append(dataset_acc)
                    dataset_losses.append(dataset_loss)
                    epoch_val_loss[valset].append(dataset_loss)
                    epoch_val_acc[valset].append(dataset_acc)
                    writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss,
                                      i)
                    writer.add_scalar(f"accuracy/{valset}/val_acc",
                                      dataset_acc, i)
                    print(
                        f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}"
                    )

                # write summaries averaged over datasets
                avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(
                    dataset_accs)
                writer.add_scalar(f"loss/avg_val_loss", avg_val_loss, i)
                writer.add_scalar(f"accuracy/avg_val_acc", avg_val_acc, i)

                # saving checkpoints; best model selected by avg accuracy
                if avg_val_acc > best_val_acc:
                    best_val_loss, best_val_acc = avg_val_loss, avg_val_acc
                    is_best = True
                    print('Best model so far!')
                else:
                    is_best = False
                # persist metric histories and adaptor state with the model
                extra_dict = {
                    'epoch_loss': epoch_loss,
                    'epoch_acc': epoch_acc,
                    'epoch_val_loss': epoch_val_loss,
                    'epoch_val_acc': epoch_val_acc,
                    'adaptors': adaptors.state_dict(),
                    'optimizer_adaptor': optimizer_adaptor.state_dict()
                }
                checkpointer.save_checkpoint(i,
                                             best_val_acc,
                                             best_val_loss,
                                             is_best,
                                             optimizer=optimizer,
                                             state_dict=model.get_state_dict(),
                                             extra=extra_dict)

                model.train()
                print(f"Trained and evaluated at {i}")

    writer.close()
    if start_iter < max_iter:
        print(
            f"""Done training with best_mean_val_loss: {best_val_loss:.3f}, best_avg_val_acc: {best_val_acc:.2f}%"""
        )
    else:
        print(
            f"""No training happened. Loaded checkpoint at {start_iter}, while max_iter was {max_iter}"""
        )
Example #4
0
def main(args):
    """Train a classifier on the configured dataset (MNIST / SOP / Shopee)
    with an optional validation split, logging, and an SGD +
    ReduceLROnPlateau schedule driven by the project's Trainer.
    """
    # fraction of un-cropped ("normal") samples in the generated data
    args.normal_data_ratio = 0.9

    # input resolution, per dataset / backbone
    if args.dataset == 'MNIST':
        args.sample_size = 28
    # BUGFIX: `args.dataset == 'SOP' or 'Shopee'` was always truthy
    # (a non-empty string literal is truthy); test membership instead.
    elif args.dataset in ('SOP', 'Shopee'):
        if args.model == 'resnet':
            args.sample_size = 224
        elif args.model == 'vgg' or args.model == 'vgg_attn':
            args.sample_size = 224
        elif args.model == 'inception':
            args.sample_size = 299
        else:
            args.sample_size = 224

    spatial_transform_train = get_train_transform(args)
    crop_transform = get_crop_transform(args)

    if args.dataset == 'MNIST':
        train_data_loader = CroppedMNISTLoader(
            args,
            crop_transform=crop_transform,
            spatial_transform=spatial_transform_train,
            training=True)
    elif args.dataset == 'SOP':
        train_data_loader = CroppedSOPLoader(
            args,
            crop_transform=crop_transform,
            spatial_transform=spatial_transform_train,
            training=True)
    elif args.dataset == 'Shopee':
        train_data_loader = ShopeeDataLoader(
            args,
            crop_transform=crop_transform,
            spatial_transform=spatial_transform_train,
            training=True)
    else:
        # fail fast instead of a NameError at split_validation()
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    valid_data_loader = train_data_loader.split_validation()

    args.n_channels = train_data_loader.n_channels
    args.n_classes = train_data_loader.n_classes

    model, parameters = get_model(args)
    model = model.to(device)

    criterion = losses.cross_entropy_loss()

    # per-epoch / per-batch / validation metric logs
    train_logger = Logger(
        os.path.join(args.log_path, 'train.log'),
        ['epoch', 'loss', 'acc', 'lr'])
    train_batch_logger = Logger(
        os.path.join(args.log_path, 'train_batch.log'),
        ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
    valid_logger = Logger(
        os.path.join(args.log_path, 'val.log'),
        ['epoch', 'loss', 'acc'])

    # record the run configuration for reproducibility
    revision_logger = Logger(
        os.path.join(args.log_path, 'revision_info.log'),
        ['dataset', 'dataset_size', 'train_test_split', 'model', 'model_depth', 'resume', 'resume_path', 'batch_size',
         'n_epochs', 'sample_size', 'crop_scale', 'crop_transform', 'cropped_data_ratio'])
    revision_logger.log({
        'dataset': args.dataset,
        'dataset_size': args.dataset_size,
        'train_test_split': args.train_test_split,
        'model': args.model,
        'model_depth': args.model_depth,
        'resume': args.resume,
        'resume_path': args.resume_path,
        'batch_size': args.batch_size,
        'n_epochs': args.n_epochs,
        'sample_size': args.sample_size,
        'crop_scale': args.crop_scale,
        'crop_transform': crop_transform.__class__.__name__,
        'cropped_data_ratio': args.cropped_data_ratio
    })

    # Nesterov momentum requires zero dampening
    if args.nesterov:
        dampening = 0
    else:
        dampening = args.dampening

    optimizer = optim.SGD(
        parameters,
        lr=args.learning_rate,
        momentum=args.momentum,
        dampening=dampening,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov)

    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args.lr_patience)

    trainer = Trainer(model, criterion, optimizer, args, device,
                      train_data_loader, lr_scheduler=scheduler,
                      valid_data_loader=valid_data_loader,
                      train_logger=train_logger,
                      batch_logger=train_batch_logger,
                      valid_logger=valid_logger)

    trainer.train()