Example #1
def main():
    global args, best_loss

    # set run output folder
    model_name = config["model_name"]
    output_dir = config["output_dir"]
    save_dir = os.path.join(output_dir, model_name)
    print(" > Output folder for this run -- {}".format(save_dir))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'plots'))

    # assign Ctrl+C signal handler
    signal.signal(signal.SIGINT, ExperimentalRunCleaner(save_dir))

    # create model
    print(" > Creating model ... !")
    model = MultiColumn(config['num_classes'], cnn_def.Model,
                        int(config["column_units"]))

    # multi GPU setting
    model = torch.nn.DataParallel(model, device_ids).to(device)

    # define optimizer
    lr = config["lr"]
    last_lr = config["last_lr"]
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    lr_decayer = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            'min',
                                                            factor=0.5,
                                                            patience=2,
                                                            verbose=True)

    # optionally resume from a checkpoint
    checkpoint_path = os.path.join(config['output_dir'], config['model_name'],
                                   'model_best.pth.tar')
    if args.resume:
        if os.path.isfile(checkpoint_path):
            print(" > Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(checkpoint_path)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            lr_decayer.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # for state in optimizer.state.values():
            #     for k, v in state.items():
            #         if isinstance(v, torch.Tensor):
            #             state[k] = v.to(device)
            print(" > Loaded checkpoint '{}' (epoch {})".format(
                checkpoint_path, checkpoint['epoch']))
        else:
            print(" !#! No checkpoint found at '{}'".format(checkpoint_path))
    elif config.get('finetune_from') is not None:
        print(' > Loading checkpoint to finetune')
        finetune_model_name = config['finetune_from']
        checkpoint_path = os.path.join(config['output_dir'],
                                       finetune_model_name,
                                       'model_best.pth.tar')
        checkpoint = torch.load(checkpoint_path)
        model.module.clf_layers = nn.Sequential(
            nn.Linear(model.module.column_units, 174)).to(device)
        model.load_state_dict(checkpoint['state_dict'])
        model.module.clf_layers = nn.Sequential(
            nn.Linear(model.module.column_units,
                      config['num_classes'])).to(device)
        print(" > Loaded checkpoint '{}' (epoch {}))".format(
            checkpoint_path, checkpoint['epoch']))
        # Freeze first 3 blocks
        for param in model.module.conv_column.block1.parameters():
            param.requires_grad = False
        for param in model.module.conv_column.block2.parameters():
            param.requires_grad = False
        for param in model.module.conv_column.block3.parameters():
            param.requires_grad = False

    # define augmentation pipeline
    upscale_size_train = int(config['input_spatial_size'] *
                             config["upscale_factor_train"])
    upscale_size_eval = int(config['input_spatial_size'] *
                            config["upscale_factor_eval"])

    # Random crop videos during training
    transform_train_pre = ComposeMix([
        [RandomRotationVideo(15), "vid"],
        [Scale(upscale_size_train), "img"],
        [RandomCropVideo(config['input_spatial_size']), "vid"],
    ])

    # Center crop videos during evaluation
    transform_eval_pre = ComposeMix([
        [Scale(upscale_size_eval), "img"],
        [torchvision.transforms.ToPILImage(), "img"],
        [
            torchvision.transforms.CenterCrop(config['input_spatial_size']),
            "img"
        ],
    ])

    # Transforms common to train and eval sets and applied after "pre" transforms
    transform_post = ComposeMix([
        [torchvision.transforms.ToTensor(), "img"],
        [
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # default values for imagenet
                std=[0.229, 0.224, 0.225]),
            "img"
        ]
    ])

    train_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_train'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_train'],
        step_size=config['step_size_train'],
        is_val=False,
        transform_pre=transform_train_pre,
        transform_post=transform_post,
        augmentation_mappings_json=config['augmentation_mappings_json'],
        augmentation_types_todo=config['augmentation_types_todo'],
        get_item_id=False,
    )

    print(" > Using {} processes for data loader.".format(
        config["num_workers"]))

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True)

    val_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_val'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_val'],
        step_size=config['step_size_val'],
        is_val=True,
        transform_pre=transform_eval_pre,
        transform_post=transform_post,
        get_item_id=True,
    )

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             pin_memory=True,
                                             drop_last=False)

    test_data = VideoFolder(
        root=config['data_folder'],
        json_file_input=config['json_data_test'],
        json_file_labels=config['json_file_labels'],
        clip_size=config['clip_size'],
        nclips=config['nclips_val'],
        step_size=config['step_size_val'],
        is_val=True,
        transform_pre=transform_eval_pre,
        transform_post=transform_post,
        get_item_id=True,
        is_test=True,
    )

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=False)

    print(" > Number of dataset classes : {}".format(len(train_data.classes)))
    assert len(train_data.classes) == config["num_classes"]

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(device)

    if args.eval_only:
        validate(val_loader, model, criterion, train_data.classes_dict)
        print(" > Evaluation DONE !")
        return

    # set callbacks
    plotter = PlotLearning(os.path.join(save_dir, "plots"),
                           config["num_classes"])
    val_loss = float('Inf')

    # set end condition by num epochs
    num_epochs = int(config["num_epochs"])
    if num_epochs == -1:
        num_epochs = 999999

    print(" > Training is getting started...")
    print(" > Training takes {} epochs.".format(num_epochs))
    start_epoch = args.start_epoch if args.resume else 0

    for epoch in range(start_epoch, num_epochs):

        lrs = [params['lr'] for params in optimizer.param_groups]
        print(" > Current LR(s) -- {}".format(lrs))
        if np.max(lrs) < last_lr and last_lr > 0:
            print(" > Training is DONE: learning rate dropped below {}".format(
                last_lr))
            sys.exit(1)

        with experiment.train():
            # train for one epoch
            train_loss, train_top1, train_top5 = train(train_loader, model,
                                                       criterion, optimizer,
                                                       epoch)
            metrics = {
                'avg_loss': train_loss,
                'avg_top1': train_top1,
                'avg_top5': train_top5,
            }
            experiment.log_metrics(metrics)

        with experiment.validate():
            # evaluate on validation set
            val_loss, val_top1, val_top5 = validate(val_loader, model,
                                                    criterion)
            metrics = {
                'avg_loss': val_loss,
                'avg_top1': val_top1,
                'avg_top5': val_top5,
            }
            experiment.log_metrics(metrics)
        experiment.log_metric('epoch', epoch)

        # set learning rate
        lr_decayer.step(val_loss)

        # plot learning
        plotter_dict = {}
        plotter_dict['loss'] = train_loss
        plotter_dict['val_loss'] = val_loss
        plotter_dict['acc'] = train_top1 / 100
        plotter_dict['val_acc'] = val_top1 / 100
        plotter_dict['learning_rate'] = lrs[0]
        plotter.plot(plotter_dict)

        print(" > Validation loss after epoch {} = {}".format(epoch, val_loss))

        # remember best loss and save the checkpoint
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': "Conv4Col",
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': lr_decayer.state_dict(),
                'best_loss': best_loss,
            }, is_best, config)
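
# Note: ExperimentalRunCleaner is registered as the SIGINT handler above, but its
# definition is not part of this example. Below is a minimal sketch of what such a
# callable might look like, modeled on the inline signal_handler in Example #2;
# the class name matches the usage above, the body is an assumption.
import glob
import os
import shutil
import sys


class ExperimentalRunCleaner(object):
    """Remove the run's output folder on Ctrl+C if nothing was written to it yet."""

    def __init__(self, save_dir):
        self.save_dir = save_dir

    def __call__(self, signum, frame):
        # Only delete the folder if it is still empty, so real runs are preserved.
        if len(glob.glob(os.path.join(self.save_dir, "*"))) < 1:
            shutil.rmtree(self.save_dir)
        print('You pressed Ctrl+C!')
        sys.exit(0)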
Example #2
def main():
    global args, best_prec1

    # set run output folder
    model_name = config["model_name"]
    output_dir = config["output_dir"]
    print("=> Output folder for this run -- {}".format(model_name))
    save_dir = os.path.join(output_dir, model_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'plots'))

    # adds a handler for Ctrl+C
    def signal_handler(signal, frame):
        """
        Remove the output dir if you exit with Ctrl+C before any
        files have been written to it. This prevents aborted
        experimental runs from cluttering the output directory.
        """
        num_files = len(glob.glob(save_dir + "/*"))
        if num_files < 1:
            shutil.rmtree(save_dir)
        print('You pressed Ctrl+C!')
        sys.exit(0)
    # assign Ctrl+C signal handler
    signal.signal(signal.SIGINT, signal_handler)

    # create model
    model = ConvColumn(config['num_classes'])

    # multi GPU setting
    if args.use_gpu:
        model = torch.nn.DataParallel(model, device_ids=gpus).to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(config['checkpoint']):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(config['checkpoint'])
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(config['checkpoint'], checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(
                config['checkpoint']))

    transform = Compose([
        CenterCrop(84),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225])
    ])

    train_data = VideoFolder(root=config['train_data_folder'],
                             csv_file_input=config['train_data_csv'],
                             csv_file_labels=config['labels_csv'],
                             clip_size=config['clip_size'],
                             nclips=1,
                             step_size=config['step_size'],
                             is_val=False,
                             transform=transform,
                             )

    print(" > Using {} processes for data loader.".format(
        config["num_workers"]))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config['batch_size'], shuffle=True,
        num_workers=config['num_workers'], pin_memory=True,
        drop_last=True)

    val_data = VideoFolder(root=config['val_data_folder'],
                           csv_file_input=config['val_data_csv'],
                           csv_file_labels=config['labels_csv'],
                           clip_size=config['clip_size'],
                           nclips=1,
                           step_size=config['step_size'],
                           is_val=True,
                           transform=transform,
                           )

    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=config['batch_size'], shuffle=False,
        num_workers=config['num_workers'], pin_memory=True,
        drop_last=False)

    assert len(train_data.classes) == config["num_classes"]

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    # define optimizer
    lr = config["lr"]
    last_lr = config["last_lr"]
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    if args.eval_only:
        validate(val_loader, model, criterion, train_data.classes_dict)
        return

    # set callbacks
    plotter = PlotLearning(os.path.join(
        save_dir, "plots"), config["num_classes"])
    lr_decayer = MonitorLRDecay(0.6, 3)
    val_loss = 9999999

    # set end condition by num epochs
    num_epochs = int(config["num_epochs"])
    if num_epochs == -1:
        num_epochs = 999999

    print(" > Training is getting started...")
    print(" > Training takes {} epochs.".format(num_epochs))
    start_epoch = args.start_epoch if args.resume else 0

    for epoch in range(start_epoch, num_epochs):
        lr = lr_decayer(val_loss, lr)
        print(" > Current LR : {}".format(lr))

        if lr < last_lr and last_lr > 0:
            print(" > Training is done: learning rate dropped below {}".format(
                last_lr))
            sys.exit(1)

        # train for one epoch
        train_loss, train_top1, train_top5 = train(
            train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, val_top1, val_top5 = validate(val_loader, model, criterion)

        # plot learning
        plotter_dict = {}
        plotter_dict['loss'] = train_loss
        plotter_dict['val_loss'] = val_loss
        plotter_dict['acc'] = train_top1
        plotter_dict['val_acc'] = val_top1
        plotter_dict['learning_rate'] = lr
        plotter.plot(plotter_dict)

        # remember best prec@1 and save checkpoint
        is_best = val_top1 > best_prec1
        best_prec1 = max(val_top1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': "Conv4Col",
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, config)
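
# Note: MonitorLRDecay (used above and in trainEnsemble below) is not defined in this
# example. Judging from the call lr = lr_decayer(val_loss, lr) and by analogy with the
# ReduceLROnPlateau(factor=0.5, patience=2) scheduler in Example #1, it plausibly
# multiplies the learning rate by a decay factor once the monitored loss has stopped
# improving for a number of calls. Hypothetical sketch; the attribute names, reset
# logic, and thresholds are assumptions, not the original implementation.
class MonitorLRDecay(object):
    """Return a decayed learning rate when the monitored loss plateaus."""

    def __init__(self, decay_factor, patience):
        self.decay_factor = decay_factor
        self.patience = patience
        self.best_loss = float('inf')
        self.num_bad_checks = 0

    def __call__(self, current_loss, current_lr):
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.num_bad_checks = 0
        else:
            self.num_bad_checks += 1
        if self.num_bad_checks >= self.patience:
            self.num_bad_checks = 0
            return current_lr * self.decay_factor
        return current_lr

# Caveat on the training loops in this example: the decayed lr is only used for the
# stopping check and for plotting; to actually change the optimizer's step size it
# would also have to be written back into optimizer.param_groups.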
def trainEnsemble():
    global args, best_prec1

    # set run output folder
    model_name = "classifier"
    output_dir = config["output_dir"]

    save_dir = os.path.join(output_dir, model_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'plots'))

    # adds a handler for Ctrl+C
    def signal_handler(signal, frame):
        """
        Remove the output dir if you exit with Ctrl+C before any
        files have been written to it. This prevents aborted
        experimental runs from cluttering the output directory.
        """
        num_files = len(glob.glob(save_dir + "/*"))
        if num_files < 1:
            shutil.rmtree(save_dir)
        print('You pressed Ctrl+C!')
        sys.exit(0)

    # assign Ctrl+C signal handler
    signal.signal(signal.SIGINT, signal_handler)

    # create model
    #model = ConvColumn(config['num_classes'])

    model0 = ConvColumn6(config['num_classes'])
    model0 = torch.nn.DataParallel(model0, device_ids=gpus).to(device)

    if os.path.isfile("trainings/jpeg_model/jester_conv6/checkpoint.pth.tar"):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(
            "trainings/jpeg_model/jester_conv6/checkpoint.pth.tar")
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model0.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            "trainings/jpeg_model/jester_conv6/checkpoint.pth.tar",
            checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(config['checkpoint']))

    model1 = ConvColumn7(config['num_classes'])
    model1 = torch.nn.DataParallel(model1, device_ids=gpus).to(device)

    if os.path.isfile("trainings/jpeg_model/jester_conv7/model_best.pth.tar"):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(
            "trainings/jpeg_model/jester_conv7/model_best.pth.tar")
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model1.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            "trainings/jpeg_model/jester_conv7/model_best.pth.tar",
            checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(config['checkpoint']))

    classifier = Classifier(config['num_classes'])
    classifier = torch.nn.DataParallel(classifier, device_ids=gpus).to(device)

    if os.path.isfile("trainings/jpeg_model/classifier/model_best.pth.tar"):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(
            "trainings/jpeg_model/classifier/model_best.pth.tar")
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        classifier.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            "trainings/jpeg_model/classifier/model_best.pth.tar",
            checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(config['checkpoint']))

    model3 = ConvColumn9(config['num_classes'])
    model3 = torch.nn.DataParallel(model3, device_ids=gpus).to(device)

    if os.path.isfile("trainings/jpeg_model/jester_conv9/model_best.pth.tar"):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(
            "trainings/jpeg_model/jester_conv9/model_best.pth.tar")
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model3.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            "trainings/jpeg_model/jester_conv9/model_best.pth.tar",
            checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(config['checkpoint']))

    model2 = ConvColumn8(config['num_classes'])
    model2 = torch.nn.DataParallel(model2, device_ids=gpus).to(device)

    if os.path.isfile("trainings/jpeg_model/jester_conv8/model_best.pth.tar"):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(
            "trainings/jpeg_model/jester_conv8/model_best.pth.tar")
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model2.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            "trainings/jpeg_model/jester_conv8/model_best.pth.tar",
            checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(config['checkpoint']))

    model4 = ConvColumn5(config['num_classes'])
    model4 = torch.nn.DataParallel(model4, device_ids=gpus).to(device)

    if os.path.isfile("trainings/jpeg_model/ConvColumn5/model_best.pth.tar"):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(
            "trainings/jpeg_model/ConvColumn5/model_best.pth.tar")
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model4.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            "trainings/jpeg_model/ConvColumn5/model_best.pth.tar",
            checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(config['checkpoint']))

    transform_train = Compose([
        RandomAffine(degrees=[-10, 10],
                     translate=[0.15, 0.15],
                     scale=[0.9, 1.1],
                     shear=[-5, 5]),
        CenterCrop(84),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    transform_valid = Compose([
        CenterCrop(84),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_data = VideoFolder(
        root=config['train_data_folder'],
        csv_file_input=config['train_data_csv'],
        csv_file_labels=config['labels_csv'],
        clip_size=config['clip_size'],
        nclips=1,
        step_size=config['step_size'],
        is_val=False,
        transform=transform_train,
    )

    print(" > Using {} processes for data loader.".format(
        config["num_workers"]))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True)

    val_data = VideoFolder(
        root=config['val_data_folder'],
        csv_file_input=config['val_data_csv'],
        csv_file_labels=config['labels_csv'],
        clip_size=config['clip_size'],
        nclips=1,
        step_size=config['step_size'],
        is_val=True,
        transform=transform_valid,
    )

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             pin_memory=True,
                                             drop_last=False)

    list_id_files = []
    for i in val_data.csv_data:
        list_id_files.append(i.path[16:])
    print(len(list_id_files))

    ###########

    assert len(train_data.classes) == config["num_classes"]

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    # define optimizer
    lr = config["lr"]
    last_lr = config["last_lr"]
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, amsgrad=True)

    #torch.optim.SGD(classifier.parameters(), lr,
    #momentum=momentum,
    #weight_decay=weight_decay)

    # set callbacks
    plotter = PlotLearning(os.path.join(save_dir, "plots"),
                           config["num_classes"])
    lr_decayer = MonitorLRDecay(0.6, 3)
    val_loss = 9999999

    # set end condition by num epochs
    num_epochs = int(config["num_epochs"])
    if num_epochs == -1:
        num_epochs = 999999

    if args.test_only:
        print("test")
        test_data = VideoFolder_test(
            root=config['val_data_folder'],
            csv_file_input=config['test_data_csv'],
            clip_size=config['clip_size'],
            nclips=1,
            step_size=config['step_size'],
            is_val=True,
            transform=transform_valid,
        )

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=False)

        list_id_files_test = []
        for i in test_data.csv_data:
            list_id_files_test.append(i.path[16:])
        print(len(list_id_files_test))
        test_ensemble(test_loader, classifier, model1, model2, model3,
                      list_id_files_test, criterion, train_data.classes_dict)
        return

    if args.eval_only:
        val_loss, val_top1, val_top5 = validate_ensemble(
            val_loader, classifier, model1, model2, model3, list_id_files,
            criterion, train_data.classes_dict)
        return

    # switch to evaluate mode
    model0.eval()
    model1.eval()
    model2.eval()
    model3.eval()
    model4.eval()
    classifier.train()

    logits_matrix = []
    targets_list = []

    new_input = np.array([])
    train_writer = tensorboardX.SummaryWriter("logs")

    for epoch in range(0, num_epochs):
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        lr = lr_decayer(val_loss, lr)
        print(" > Current LR : {}".format(lr))

        if lr < last_lr and last_lr > 0:
            print(" > Training is done: learning rate dropped below {}".format(
                last_lr))
            sys.exit(1)
        for i, (input, target) in enumerate(train_loader):
            input, target = input.to(device), target.to(device)

            with torch.no_grad():

                # compute output and loss
                output0, feature0 = model0(input)
                output1, feature1 = model1(input)
                output2, feature2 = model2(input)
                output3, feature3 = model3(input)
                output4, feature4 = model4(input)
                #sav=torch.cat((feature0,feature1,feature2,feature3,feature4),1)
                sav = torch.cat((output0, output1, output2, output3, output4),
                                1)
            classifier.zero_grad()
            class_video = classifier(sav)
            loss = criterion(class_video, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(class_video.detach(),
                                    target.detach().cpu(),
                                    topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % config["print_freq"] == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch,
                          i,
                          len(train_loader),
                          loss=losses,
                          top1=top1,
                          top5=top5))

        val_loss, val_top1, val_top5 = validate_ensemble(
            val_loader, classifier, model0, model1, model2, model3, model4,
            list_id_files, criterion)

        train_writer.add_scalar('loss', losses.avg, epoch + 1)
        train_writer.add_scalar('top1', top1.avg, epoch + 1)
        train_writer.add_scalar('top5', top5.avg, epoch + 1)

        train_writer.add_scalar('val_loss', val_loss, epoch + 1)
        train_writer.add_scalar('val_top1', val_top1, epoch + 1)
        train_writer.add_scalar('val_top5', val_top5, epoch + 1)

        # remember best prec@1 and save checkpoint
        is_best = val_top1 > best_prec1
        best_prec1 = max(val_top1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': "Classifier",
                'state_dict': classifier.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, config)
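
# Note: the Classifier fusion head used in trainEnsemble is not defined in this
# example. It receives torch.cat((output0, ..., output4), 1), i.e. the concatenated
# logits of five backbone columns, and returns final class scores. The following is
# a minimal sketch consistent with that usage, under the assumption that each column
# outputs num_classes logits; the original Classifier may differ.
import torch.nn as nn


class Classifier(nn.Module):
    """Late-fusion head: maps concatenated per-model logits to class scores."""

    def __init__(self, num_classes, num_models=5):
        super(Classifier, self).__init__()
        self.fuse = nn.Sequential(
            nn.Linear(num_classes * num_models, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, stacked_logits):
        # stacked_logits: (batch_size, num_models * num_classes)
        return self.fuse(stacked_logits)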