def main():
    """Entry point: build the model, optionally resume, then train and validate.

    Command-line options come from the module-level ``parser``; the parsed
    namespace is published as the global ``args`` and the best top-1 precision
    seen so far is tracked in the global ``best_prec1``.  With ``--evaluate``
    only a single validation pass is run.  A ``--resume`` path that is not an
    existing file terminates the process with exit status -1.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Build the network; two Inception variants are created with 1001 outputs.
    wants_1001 = args.model in ('inception_resnet_v2', 'inception_v4')
    model = create_model(
        args.model,
        num_classes=1001 if wants_1001 else 1000,
        pretrained=args.pretrained)

    sgd_options = dict(momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = torch.optim.SGD(model.parameters(), args.lr, **sgd_options)

    # Three cases: no resume requested, resume path missing, resume path present.
    resume_path = args.resume
    if not resume_path:
        if args.sparse:
            print("Sparsifying model")
            dense_sparse_dense.sparsify(model, sparsity=0.5)
    elif not os.path.isfile(resume_path):
        print("=> no checkpoint found at '{}'".format(resume_path))
        exit(-1)
    else:
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        has_training_state = isinstance(checkpoint, dict) and 'state_dict' in checkpoint
        if not has_training_state:
            # The file holds nothing but a bare state dict.
            model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}'".format(resume_path))
        else:
            args.start_epoch = checkpoint['epoch']
            sparse_checkpoint = bool(checkpoint.get('sparse'))
            if sparse_checkpoint:
                print("Loading sparse model")
                # Sparsify with zero sparsity before loading the sparse weights.
                dense_sparse_dense.sparsify(model, sparsity=0.)
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Convert between dense and sparse when the request and the
            # checkpoint disagree.
            if args.sparse and not sparse_checkpoint:
                print("Sparsifying loaded model")
                dense_sparse_dense.sparsify(model, sparsity=0.5)
            elif sparse_checkpoint and not args.sparse:
                print("Densifying loaded model")
                dense_sparse_dense.densify(model)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))

    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    # Input pipelines: <data>/train and <data>/validation image folders.
    train_dir = os.path.join(args.data, 'train')
    val_dir = os.path.join(args.data, 'validation')
    normalize = (
        LeNormalize() if 'inception' in args.model
        else transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    shared_loader_options = dict(
        batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)

    train_transform = transforms.Compose([
        transforms.RandomSizedCrop(args.img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_set = datasets.ImageFolder(train_dir, train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, shuffle=True, **shared_loader_options)

    # Validation images are resized to img_size / 0.875, then centre-cropped.
    val_transform = transforms.Compose([
        transforms.Scale(int(math.floor(args.img_size/0.875))),
        transforms.CenterCrop(args.img_size),
        transforms.ToTensor(),
        normalize,
    ])
    val_set = datasets.ImageFolder(val_dir, val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_set, shuffle=False, **shared_loader_options)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, decay_epochs=args.decay_epochs)

        # One pass over the training data, then score on the validation data.
        train(train_loader, model, criterion, optimizer, epoch)
        prec1 = validate(val_loader, model, criterion)

        # Track the best top-1 precision and write this epoch's checkpoint.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        state = {
            'epoch': epoch + 1,
            'arch': args.model,
            'state_dict': model.state_dict(),
            'sparse': args.sparse,
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(state, is_best, filename='checkpoint-%d.pth.tar' % epoch)
def main():
    """Evaluate a model on an image folder and print Prec@1 / Prec@5.

    Options are read from the module-level ``parser`` and published as the
    global ``args``.  Weights are restored from ``args.restore_checkpoint``
    when that file exists; otherwise the run continues only if
    ``--pretrained`` was requested, and exits with status -1 if not.

    NOTE: the evaluation loop uses the PyTorch >= 0.4 API
    (``torch.no_grad()``, ``non_blocking=True``, ``Tensor.item()``).  The
    earlier ``cuda(async=True)`` spelling cannot be parsed by Python 3.7+,
    where ``async`` is a reserved keyword.
    """
    global args, best_prec1
    args = parser.parse_args()

    # create model; Inception variants are created with 1001 outputs
    num_classes = 1001 if 'inception' in args.model else 1000
    model = create_model(args.model,
                         num_classes=num_classes,
                         pretrained=args.pretrained)

    print('Model %s created, param count: %d' %
          (args.model, sum(p.numel() for p in model.parameters())))

    # optionally restore weights from a checkpoint
    ckpt_path = args.restore_checkpoint
    if ckpt_path and os.path.isfile(ckpt_path):
        print("=> loading checkpoint '{}'".format(ckpt_path))
        checkpoint = torch.load(ckpt_path)
        if 'sparse' in checkpoint and checkpoint['sparse']:
            print("Loading sparse model")
            dense_sparse_dense.sparsify(
                model,
                sparsity=0.)  # ensure sparsity_masks exist in model definition
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # the file is a bare state dict
            model.load_state_dict(checkpoint)
        print("=> loaded checkpoint '{}'".format(ckpt_path))
    elif not args.pretrained:
        print("=> no checkpoint found at '{}'".format(ckpt_path))
        exit(-1)
    elif ckpt_path:
        # A path was given but does not exist; say so instead of silently
        # evaluating the pretrained weights.
        print("=> WARNING: no checkpoint found at '{}', "
              "evaluating pretrained weights".format(ckpt_path))

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # Data loading code: resize target and normalisation depend on the model
    # family selected on the command line.
    scale_size = int(math.floor(args.img_size / 0.875))
    if 'inception' in args.model:
        normalize = LeNormalize()
        scale_size = args.img_size
    elif 'dpn' in args.model:
        if args.img_size != 224:
            scale_size = args.img_size
        normalize = transforms.Normalize(
            mean=[124 / 255, 117 / 255, 104 / 255],
            std=[1 / (.0167 * 255)] * 3)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    dataset = datasets.ImageFolder(
        args.data,
        transforms.Compose([
            transforms.Scale(scale_size, Image.BICUBIC),
            transforms.CenterCrop(args.img_size),
            transforms.ToTensor(),
            normalize,
        ]))

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for i, (images, target) in enumerate(loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            if num_classes == 1001:
                # drop output column 0 so the remaining 1000 columns line up
                # with the dataset's target indices
                output = output[:, 1:]
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            batch_len = images.size(0)
            losses.update(loss.item(), batch_len)
            top1.update(prec1[0], batch_len)
            top5.update(prec5[0], batch_len)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1,
                                                                  top5=top5))
def main():
    """Run inference with a restored checkpoint and write prediction CSVs.

    Options are read from the module-level ``parser``.  The model is restored
    from ``args.restore_checkpoint``, every batch of the dataset is pushed
    through it, and three files are written to the output directory:
    ``results_raw.csv`` (sigmoid outputs), ``results_thr.csv`` (thresholded
    outputs) and ``submission.csv`` (tag strings).  Pressing Ctrl-C stops the
    loop early; whatever was collected so far is still written.

    Raises:
        ValueError: if no checkpoint was specified.
        FileNotFoundError: if the checkpoint path is not an existing file.
    """
    args = parser.parse_args()

    # Validate the checkpoint argument up front with real exceptions.  The
    # previous `assert False and "..."` never displayed its message and is
    # stripped under `python -O`, leaving later variables undefined.
    if args.restore_checkpoint is None:
        raise ValueError("No checkpoint specified")
    if not os.path.isfile(args.restore_checkpoint):
        raise FileNotFoundError('%s not found' % args.restore_checkpoint)

    batch_size = args.batch_size
    img_size = (args.img_size, args.img_size)
    num_classes = 17
    img_type = '.tif' if args.tif else '.jpg'

    dataset = AmazonDataset(
        args.data,
        train=False,
        multi_label=args.multi_label,
        tags_type='all',
        img_type=img_type,
        img_size=img_size,
        test_aug=args.tta,
    )

    tags = get_tags()
    output_col = ['image_name'] + tags
    submission_col = ['image_name', 'tags']

    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=args.num_processes)

    model = create_model(args.model,
                         pretrained=False,
                         num_classes=num_classes,
                         global_pool=args.gp)

    if not args.no_cuda:
        if args.num_gpu > 1:
            model = torch.nn.DataParallel(model,
                                          device_ids=list(range(
                                              args.num_gpu))).cuda()
        else:
            model.cuda()

    # Restore weights and the metadata stored alongside them.
    checkpoint = torch.load(args.restore_checkpoint)
    print('Restoring model with %s architecture...' % checkpoint['arch'])
    if checkpoint.get('sparse'):
        print("Loading sparse model")
        dense_sparse_dense.sparsify(
            model,
            sparsity=0.)  # ensure sparsity_masks exist in model definition
    model.load_state_dict(checkpoint['state_dict'])
    # May be absent; only consulted for the fallback experiment name below.
    train_args = checkpoint.get('args')
    if 'threshold' in checkpoint:
        threshold = torch.FloatTensor(checkpoint['threshold'])
        print('Using thresholds:', threshold)
        if not args.no_cuda:
            threshold = threshold.cuda()
    else:
        threshold = 0.5
    if 'gp' in checkpoint and checkpoint['gp'] != args.gp:
        print(
            "Warning: Model created with global pooling (%s) different from checkpoint (%s)"
            % (args.gp, checkpoint['gp']))
    print('Model restored from file: %s' % args.restore_checkpoint)

    # Name the run "<parent dir>-<checkpoint file stem>" when the checkpoint
    # path has a parent component; otherwise build a name from the settings.
    csplit = os.path.normpath(args.restore_checkpoint).split(sep=os.path.sep)
    if len(csplit) > 1:
        exp_name = csplit[-2] + '-' + csplit[-1].split('.')[0]
    else:
        if train_args is not None:
            size_part = str(train_args.img_size)
            fold_part = ['f' + str(train_args.fold)]
        else:
            # The checkpoint carries no training args (this used to raise a
            # NameError); fall back to the current image size, fold unknown.
            size_part = str(args.img_size)
            fold_part = []
        exp_name = '-'.join(
            [args.model, size_part] + fold_part +
            ['tif' if args.tif else 'jpg'])

    output_base = args.output if args.output else './output'
    output_dir = get_outdir(output_base, 'predictions', exp_name)

    model.eval()

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    results_raw = []
    results_thr = []
    results_sub = []
    try:
        end = time.time()
        for batch_idx, (batch, target, index) in enumerate(loader):
            data_time_m.update(time.time() - end)
            if not args.no_cuda:
                batch = batch.cuda()
            input_var = autograd.Variable(batch, volatile=True)
            output = model(input_var)

            # augmentation reduction: average each run of `reduce_factor`
            # consecutive rows and keep one index per run.
            # NOTE(review): the trailing squeeze assumes mean(dim=2) keeps
            # the reduced dimension; confirm against the installed torch.
            reduce_factor = loader.dataset.get_aug_factor()
            if reduce_factor > 1:
                output.data = output.data.unfold(
                    0, reduce_factor, reduce_factor).mean(dim=2).squeeze(dim=2)
                index = index[0:index.size(0):reduce_factor]

            # output non-linearity and thresholding
            output = torch.sigmoid(output)
            if torch.is_tensor(threshold):
                # thresholds restored from the checkpoint as a tensor
                threshold_m = torch.unsqueeze(threshold,
                                              0).expand_as(output.data)
                output_thr = (output.data > threshold_m).byte()
            else:
                output_thr = (output.data > threshold).byte()

            # move data to CPU and collect
            output = output.cpu().data.numpy()
            output_thr = output_thr.cpu().numpy()
            index = index.cpu().numpy().flatten()
            for i, o, ot in zip(index, output, output_thr):
                image_name = os.path.splitext(
                    os.path.basename(dataset.inputs[i]))[0]
                results_raw.append([image_name] + list(o))
                results_thr.append([image_name] + list(ot))
                results_sub.append([image_name] + [vector_to_tags(ot, tags)])

            batch_time_m.update(time.time() - end)
            if batch_idx % args.log_interval == 0:
                print('Inference: [{}/{} ({:.0f}%)]  '
                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                          batch_idx * len(batch),
                          len(loader.sampler),
                          100. * batch_idx / len(loader),
                          batch_time=batch_time_m,
                          rate=input_var.size(0) / batch_time_m.val,
                          rate_avg=input_var.size(0) / batch_time_m.avg,
                          data_time=data_time_m))

            end = time.time()
    except KeyboardInterrupt:
        # Deliberate: an interrupted run still writes its partial results.
        pass

    for rows, columns, filename in (
            (results_raw, output_col, 'results_raw.csv'),
            (results_thr, output_col, 'results_thr.csv'),
            (results_sub, submission_col, 'submission.csv')):
        pd.DataFrame(rows, columns=columns).to_csv(
            os.path.join(output_dir, filename), index=False)
# Example #4
0
def main():
    """Validate a restored checkpoint on the eval (and optionally train) data.

    Options are read from the module-level ``parser``.  The model is restored
    from ``args.restore_checkpoint`` and passed to ``validate`` for the
    evaluation dataset, and first for the training dataset when ``--train``
    is given.  Results go to a time-stamped directory under ``./output``.

    Raises:
        ValueError: if no checkpoint was specified, the loss name is not
            ``nll`` or ``mlsm``, or ``mlsm`` is used without multi-label.
        FileNotFoundError: if the checkpoint path is not an existing file.
    """
    args = parser.parse_args()

    # Validate arguments before any directory, dataset or model is created.
    # The previous `assert False and "..."` checks never displayed their
    # message and disappear entirely under `python -O`.
    if args.restore_checkpoint is None:
        raise ValueError("No checkpoint specified")
    if not os.path.isfile(args.restore_checkpoint):
        raise FileNotFoundError('%s not found' % args.restore_checkpoint)
    loss_name = args.loss.lower()
    if loss_name not in ('nll', 'mlsm'):
        raise ValueError("Invalid loss function")
    if loss_name == 'mlsm' and not args.multi_label:
        raise ValueError("The 'mlsm' loss requires multi-label targets")

    train_input_root = args.data
    train_labels_file = './data/labels.csv'
    output_dir = get_outdir('./output', 'eval', datetime.now().strftime("%Y%m%d-%H%M%S"))

    batch_size = args.batch_size
    img_type = '.tif' if args.tif else '.jpg'
    img_size = (args.img_size, args.img_size)
    num_classes = get_tags_size(args.labels)

    torch.manual_seed(args.seed)

    if args.train:
        # Training-fold samples, loaded with train=False and without shuffling.
        dataset_train = AmazonDataset(
            train_input_root,
            train_labels_file,
            train=False,
            train_fold=True,
            tags_type=args.labels,
            multi_label=args.multi_label,
            img_type=img_type,
            img_size=img_size,
            fold=args.fold,
        )

        loader_train = data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            shuffle=False,
            num_workers=args.num_processes
        )

    dataset_eval = AmazonDataset(
        train_input_root,
        train_labels_file,
        train=False,
        tags_type=args.labels,
        multi_label=args.multi_label,
        img_type=img_type,
        img_size=img_size,
        test_aug=args.tta,
        fold=args.fold,
    )

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_processes
    )

    model = create_model(args.model, pretrained=args.pretrained, num_classes=num_classes, global_pool=args.gp)

    if not args.no_cuda:
        if args.num_gpu > 1:
            model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        else:
            model.cuda()

    # Class weighting was permanently disabled here (`if False:`), so both
    # losses are built unweighted.
    if loss_name == 'nll':
        loss_fn = torch.nn.CrossEntropyLoss(weight=None)
    else:
        loss_fn = torch.nn.MultiLabelSoftMarginLoss(weight=None)

    if not args.no_cuda:
        loss_fn = loss_fn.cuda()

    # load the checkpoint
    checkpoint = torch.load(args.restore_checkpoint)
    print('Restoring model with %s architecture...' % checkpoint['arch'])
    if checkpoint.get('sparse'):
        print("Loading sparse model")
        dense_sparse_dense.sparsify(model, sparsity=0.)  # ensure sparsity_masks exist in model definition
    model.load_state_dict(checkpoint['state_dict'])
    if 'threshold' in checkpoint:
        threshold = torch.FloatTensor(checkpoint['threshold'])
        print('Using thresholds:', threshold)
        if not args.no_cuda:
            threshold = threshold.cuda()
    else:
        threshold = 0.5
    if 'gp' in checkpoint and checkpoint['gp'] != args.gp:
        print("Warning: Model created with global pooling (%s) different from checkpoint (%s)"
              % (args.gp, checkpoint['gp']))
    print('Model restored from file: %s' % args.restore_checkpoint)

    if args.train:
        print('Validating training data...')
        validate(
            model, loader_train, loss_fn, args, threshold, prefix='train', output_dir=output_dir)

    print('Validating validation data...')
    validate(
        model, loader_eval, loss_fn, args, threshold, prefix='eval', output_dir=output_dir)
def main():
    """Train a model on the Amazon dataset, validating after every epoch.

    Options are read from the module-level ``parser``.  The run writes a
    ``summary.csv`` row and a checkpoint per epoch into a time-stamped output
    directory, optionally reports to a Crayon/TensorBoard server, and can be
    preceded by a fine-tuning phase that only updates the final classifier.
    Ctrl-C during the main loop stops training and prints the best results
    recorded so far.

    Raises:
        ValueError: if the optimizer or loss name is not recognised, or the
            ``mlsm`` loss is requested without multi-label targets.
    """
    args = parser.parse_args()

    train_input_root = args.data
    train_labels_file = './data/labels.csv'

    output_base = args.output if args.output else './output'

    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        str(args.img_size),
        'f'+str(args.fold),
        'tif' if args.tif else 'jpg'])
    output_dir = get_outdir(output_base, 'train', exp_name)

    batch_size = args.batch_size
    num_epochs = args.epochs
    img_type = '.tif' if args.tif else '.jpg'
    img_size = (args.img_size, args.img_size)
    num_classes = get_tags_size(args.labels)

    torch.manual_seed(args.seed)

    dataset_train = AmazonDataset(
        train_input_root,
        train_labels_file,
        train=True,
        tags_type=args.labels,
        multi_label=args.multi_label,
        img_type=img_type,
        img_size=img_size,
        fold=args.fold,
    )

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_processes
    )

    dataset_eval = AmazonDataset(
        train_input_root,
        train_labels_file,
        train=False,
        tags_type=args.labels,
        multi_label=args.multi_label,
        img_type=img_type,
        img_size=img_size,
        test_aug=args.tta,
        fold=args.fold,
    )

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_processes
    )

    model = model_factory.create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=num_classes,
        drop_rate=args.drop,
        global_pool=args.gp)

    if not args.no_cuda:
        if args.num_gpu > 1:
            model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        else:
            model.cuda()

    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        optimizer = optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif opt_name == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif opt_name == 'adadelta':
        optimizer = optim.Adadelta(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif opt_name == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=args.lr, alpha=0.9, momentum=args.momentum, weight_decay=args.weight_decay)
    elif opt_name == 'yellowfin':
        optimizer = YFOptimizer(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, clip_thresh=2)
    else:
        # Was `assert False and "..."`, which hid the message and vanished
        # under `python -O`.
        raise ValueError("Invalid optimizer")

    # Without a fixed decay schedule, reduce the LR when eval loss plateaus.
    if not args.decay_epochs:
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=8)
    else:
        lr_scheduler = None

    if args.class_weights:
        class_weights = torch.from_numpy(dataset_train.get_class_weights()).float()
        class_weights_norm = class_weights / class_weights.sum()
        if not args.no_cuda:
            class_weights = class_weights.cuda()
            class_weights_norm = class_weights_norm.cuda()
    else:
        class_weights = None
        class_weights_norm = None

    loss_name = args.loss.lower()
    if loss_name == 'nll':
        loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    elif loss_name == 'mlsm':
        if not args.multi_label:
            raise ValueError("The 'mlsm' loss requires multi-label targets")
        loss_fn = torch.nn.MultiLabelSoftMarginLoss(weight=class_weights)
    else:
        raise ValueError("Invalid loss function")

    if not args.no_cuda:
        loss_fn = loss_fn.cuda()

    # optionally resume from a checkpoint
    start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            sparse_checkpoint = bool(checkpoint.get('sparse'))
            if sparse_checkpoint:
                print("Loading sparse model")
                dense_sparse_dense.sparsify(model, sparsity=0.)  # ensure sparsity_masks exist in model definition
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            start_epoch = checkpoint['epoch']
            # Convert between dense and sparse when the request and the
            # checkpoint disagree.
            if args.sparse and not sparse_checkpoint:
                print("Sparsifying loaded model")
                dense_sparse_dense.sparsify(model, sparsity=0.5)
            elif sparse_checkpoint and not args.sparse:
                print("Densifying loaded model")
                dense_sparse_dense.densify(model)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit(-1)
    else:
        if args.sparse:
            dense_sparse_dense.sparsify(model, sparsity=0.5)

    # Optional Crayon/TensorBoard reporting; any failure disables it.
    exp = None
    use_tensorboard = not args.no_tb and CrayonClient is not None
    if use_tensorboard:
        # args.tbh is "host" or "host:port"; missing parts use the defaults.
        host_port = args.tbh.split(':')[:2]
        hostname = host_port[0] or '127.0.0.1'
        port = host_port[1] if len(host_port) == 2 else 8889
        try:
            cc = CrayonClient(hostname=hostname, port=port)
            try:
                cc.remove_experiment(exp_name)
            except ValueError:
                # presumably no experiment of that name exists yet; verify
                # against the Crayon client
                pass
            exp = cc.create_experiment(exp_name)
        except Exception as e:
            exp = None
            print("Error (%s) connecting to Tensoboard/Crayon server. Giving up..." % str(e))

    # Optional fine-tune of only the final classifier weights for specified number of epochs (or part of)
    if not args.resume and args.ft_epochs > 0.:
        # DataParallel does not forward custom methods such as get_fc(), so
        # reach through to the wrapped module when the model is wrapped.
        base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
        fc_params = base_model.get_fc().parameters()
        if opt_name == 'adam':
            finetune_optimizer = optim.Adam(
                fc_params, lr=args.ft_lr, weight_decay=args.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                fc_params, lr=args.ft_lr, momentum=args.momentum, weight_decay=args.weight_decay)

        # A fractional ft_epochs means the last fine-tune epoch is cut short
        # after the corresponding share of its batches.
        finetune_epochs_int = int(np.ceil(args.ft_epochs))
        finetune_final_batches = int(np.ceil((1 - (finetune_epochs_int - args.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(1, finetune_epochs_int + 1):
            if fepoch == finetune_epochs_int and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, loss_fn, args,
                class_weights_norm, output_dir, batch_limit=batch_limit)
            step = fepoch * len(loader_train)
            validate(step, model, loader_eval, loss_fn, args, 0.3, output_dir)

    score_metric = 'f2'
    best_loss = None
    best_f2 = None
    threshold = 0.3
    try:
        for epoch in range(start_epoch, num_epochs + 1):
            if args.decay_epochs:
                adjust_learning_rate(optimizer, epoch, initial_lr=args.lr, decay_epochs=args.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, loss_fn, args, class_weights_norm, output_dir, exp=exp)

            step = epoch * len(loader_train)
            eval_metrics, latest_threshold = validate(
                step, model, loader_eval, loss_fn, args, threshold, output_dir, exp=exp)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # The checkpoint is flagged "best" only for the metric selected
            # by score_metric; both bests are tracked for the final report.
            best = False
            if best_loss is None or eval_metrics['eval_loss'] < best_loss[1]:
                best_loss = (epoch, eval_metrics['eval_loss'])
                if score_metric == 'loss':
                    best = True
            if best_f2 is None or eval_metrics['eval_f2'] > best_f2[1]:
                best_f2 = (epoch, eval_metrics['eval_f2'])
                if score_metric == 'f2':
                    best = True

            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'sparse': args.sparse,
                'state_dict':  model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'threshold': latest_threshold,
                'args': args,
                'gp': args.gp,
                },
                is_best=best,
                filename='checkpoint-%d.pth.tar' % epoch,
                output_dir=output_dir)

    except KeyboardInterrupt:
        # Deliberate: stop training and fall through to the summary.
        pass

    # Both stay None when no epoch finished (early interrupt, or a resumed
    # start_epoch beyond num_epochs); indexing them would raise a TypeError.
    if best_loss is not None and best_f2 is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
        print('*** Best f2: {0} (epoch {1})'.format(best_f2[1], best_f2[0]))
    else:
        print('*** No epoch completed; no best loss/f2 to report')