Code Example #1
def main():
    global global_epoch_confusion
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir',
                        type=str,
                        default='log',
                        help='path for saving trained models and log info')
    parser.add_argument('--ann_dir',
                        type=str,
                        default='/media/data/dataset/coco/annotations',
                        help='path for annotation json file')
    parser.add_argument('--image_dir', default='/media/data/dataset/coco')

    parser.add_argument('--resume',
                        default=1,
                        type=int,
                        help='whether to resume from log_dir if existent')
    parser.add_argument('--finetune', default=0, type=int)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--start_epoch', type=int, default=1)
    parser.add_argument('--batch_size', type=int,
                        default=64)  # batch size should be smaller if text features are used
    parser.add_argument('--crop_size', type=int, default=224)
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--lam',
                        default=0.5,
                        type=float,
                        help='hyperparameter lambda')
    parser.add_argument('--first',
                        default="person",
                        type=str,
                        help='first object index')
    parser.add_argument('--second',
                        default="clock",
                        type=str,
                        help='second object index')
    parser.add_argument('--third',
                        default="bus",
                        type=str,
                        help='third object index')
    parser.add_argument('--pretrained',
                        default='/set/your/model/path',
                        type=str,
                        metavar='PATH')
    parser.add_argument('--debug',
                        help='Check model accuracy',
                        action='store_true')
    parser.add_argument('--weight',
                        default=1,
                        type=float,
                        help='oversampling weight')
    parser.add_argument('--target_weight',
                        default=0,
                        type=float,
                        help='target_weight')
    parser.add_argument('--class_num',
                        default=81,
                        type=int,
                        help='81:coco_gender;80:coco')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.log_dir) and not args.resume:
        print('Path {} exists and resume is disabled; aborting.'.format(args.log_dir))
        return
    if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)

    #save all parameters for training
    with open(os.path.join(args.log_dir, "arguments.log"), "a") as f:
        f.write(str(args) + '\n')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Image preprocessing
    train_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(), normalize
    ])

    # Data samplers.
    train_data = CocoObject(ann_dir=args.ann_dir,
                            image_dir=args.image_dir,
                            split='train',
                            transform=train_transform)

    val_data = CocoObject(ann_dir=args.ann_dir,
                          image_dir=args.image_dir,
                          split='val',
                          transform=val_transform)
    object2id = val_data.object2id

    first_id = object2id[args.first]
    second_id = object2id[args.second]
    third_id = object2id[args.third]

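    # Oversample: give weight args.weight to every training image whose label
    # set contains any of the three target objects; all other images keep 1.0.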
    weights = [
        args.weight if first_id in train_data.labels[i] or second_id
        in train_data.labels[i] or third_id in train_data.labels[i] else 1.0
        for i in range(len(train_data.labels))
    ]
    sampler = WeightedRandomSampler(torch.DoubleTensor(weights),
                                    len(train_data.labels))

    # Data loaders / batch assemblers.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               num_workers=1,
                                               pin_memory=True,
                                               sampler=sampler)

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=True)

    # Build the models
    model = MultilabelObject(args, args.class_num).cuda()
    criterion = nn.BCEWithLogitsLoss(weight=torch.FloatTensor(
        train_data.getObjectWeights()),
                                     reduction='mean').cuda()

    def trainable_params():
        for param in model.parameters():
            if param.requires_grad:
                yield param

    optimizer = torch.optim.Adam(trainable_params(),
                                 args.learning_rate,
                                 weight_decay=1e-5)

    best_performance = 0
    if os.path.isfile(args.pretrained):
        train_F = open(os.path.join(args.log_dir, 'train.csv'), 'w')
        val_F = open(os.path.join(args.log_dir, 'val.csv'), 'w')
        score_F = open(os.path.join(args.log_dir, 'score.csv'), 'w')
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        args.start_epoch = checkpoint['epoch']
        best_performance = checkpoint['best_performance']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        exit()

    for epoch in range(args.start_epoch, args.num_epochs + 1):
        global_epoch_confusion.append({})
        adjust_learning_rate(optimizer, epoch)
        train(args, epoch, model, criterion, train_loader, optimizer, train_F,
              score_F, train_data, object2id)
        current_performance = get_confusion(args, epoch, model, criterion,
                                            val_loader, optimizer, val_F,
                                            score_F, val_data)
        is_best = current_performance > best_performance
        best_performance = max(current_performance, best_performance)
        model_state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_performance': best_performance
        }
        save_checkpoint(args, model_state, is_best,
                        os.path.join(args.log_dir, 'checkpoint.pth.tar'))
        confusion_matrix = global_epoch_confusion[-1]["confusion"]
        first_second = compute_confusion(confusion_matrix, args.first,
                                         args.second)
        first_third = compute_confusion(confusion_matrix, args.first,
                                        args.third)
        print(
            str((args.first, args.second, args.third)) + " triplet: " + str(
                compute_bias(confusion_matrix, args.first, args.second,
                             args.third)))
        print(str((args.first, args.second)) + ": " + str(first_second))
        print(str((args.first, args.third)) + ": " + str(first_third))
        #os.system('python plot.py {} &'.format(args.log_dir))

    train_F.close()
    val_F.close()
    score_F.close()
    np.save(os.path.join(args.log_dir, 'global_epoch_confusion.npy'),
            global_epoch_confusion)
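
The oversampling above relies on torch.utils.data.WeightedRandomSampler, which draws dataset indices with replacement in proportion to the given weights. A minimal self-contained sketch of the same idea on toy labels (all names here are illustrative, not from the project):

import torch
from torch.utils.data import WeightedRandomSampler

# toy multi-label dataset: index -> set of object ids present in the image
toy_labels = [{1, 5}, {2}, {5, 7}, {3}, {2, 5}]
target_ids = {5, 7}  # objects we want to see more often
boost = 3.0          # plays the role of args.weight

weights = [boost if target_ids & labels else 1.0 for labels in toy_labels]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights),
                                num_samples=len(toy_labels),
                                replacement=True)
print(list(sampler))  # indices 0, 2, 4 are drawn ~3x as often on average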
Code Example #2
File: infer_batch.py Project: yqtianust/ASL
def test_with_loader():

    # model, input_size, threshold, num_classes, classes_list = load_model("XL")
    model, input_size, threshold, num_classes, classes_list = load_model("L")

    import torch
    from torchvision.transforms import transforms
    from data_loader import CocoObject
    from torch.autograd import Variable
    from sklearn.metrics import average_precision_score
    import torch.nn as nn
    from tqdm import tqdm
    # crop_size = 224
    # image_size = 256
    # batch_size = 4
    batch_size = 12
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                  std=[0.229, 0.224, 0.225])

    val_transform = transforms.Compose([
        transforms.Resize([input_size, input_size]),
        # transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ])

    # Data samplers.
    ann_dir = '/home/ytianas/EMSE_COCO/cocodataset/annotations'
    image_dir = '/home/ytianas/EMSE_COCO/cocodataset/'
    test_data = CocoObject(ann_dir=ann_dir,
                           image_dir=image_dir,
                           split='test',
                           transform=val_transform)
    image_ids = test_data.image_ids
    image_path_map = test_data.image_path_map
    # 80 objects
    id2object = test_data.id2object
    id2labels = test_data.id2labels
    # Data loaders / batch assemblers.
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)
    count = 0
    yhats = []
    labels = []
    imagefiles = []
    res = list()

    t = tqdm(test_loader, desc='testing')

    for batch_idx, (images, objects, image_ids) in enumerate(t):

        images = Variable(images).cuda()
        objects = Variable(objects).cuda()

        # print(images.shape)

        object_preds = model(images)
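        # multi-label model: a sigmoid turns raw logits into independent
        # per-class probabilities, each thresholded separately below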
        m = nn.Sigmoid()
        object_preds_r = m(object_preds)

        count = count + len(image_ids)
        for i in range(len(image_ids)):
            image_file_name = image_path_map[int(image_ids[i])]
            yhat = []
            label = id2labels[int(image_ids[i])]

            for j in range(len(object_preds[i])):
                a = object_preds_r[i][j].cpu().data.numpy()
                if a > threshold:
                    yhat.append(id2object[j])

            yhats.append(yhat)
            labels.append(label)
            imagefiles.append(image_file_name)

        res.append((image_ids, object_preds.data.cpu(), objects.data.cpu()))
        if count % 1000 == 0:
            print("count: " + str(count))

    preds_object = torch.cat([entry[1] for entry in res], 0)
    targets_object = torch.cat([entry[2] for entry in res], 0)
    eval_score_object = average_precision_score(targets_object.numpy(),
                                                preds_object.numpy())
    print('\nmean average precision of object classifier on test data is {}\n'.
          format(eval_score_object))
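
The final score uses sklearn's average_precision_score, which for 2-D inputs macro-averages the per-class average precision. Note the code feeds it raw logits (object_preds, not the sigmoided object_preds_r), which is fine because AP depends only on the ranking of scores within each class. A toy sketch of the same call (values are illustrative):

import numpy as np
from sklearn.metrics import average_precision_score

# 3 samples x 2 classes: binary targets and predicted scores
targets = np.array([[1, 0], [0, 1], [1, 1]])
scores = np.array([[0.8, 0.1], [0.2, 0.9], [0.6, 0.4]])
print(average_precision_score(targets, scores))  # macro-averaged AP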
Code Example #3
def main():
    global global_epoch_confusion
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir',
                        type=str,
                        default='log',
                        help='path for saving trained models and log info')
    parser.add_argument('--ann_dir',
                        type=str,
                        default='/media/data/dataset/coco/annotations',
                        help='path for annotation json file')
    parser.add_argument('--image_dir', default='/media/data/dataset/coco')

    parser.add_argument('--resume',
                        default=1,
                        type=int,
                        help='whether to resume from log_dir if existent')
    parser.add_argument('--finetune', default=0, type=int)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--start_epoch', type=int, default=1)
    parser.add_argument('--batch_size', type=int,
                        default=64)  # batch size should be smaller if text features are used
    parser.add_argument('--crop_size', type=int, default=224)
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--lam',
                        default=0.5,
                        type=float,
                        help='hyperparameter lambda')
    parser.add_argument('--first',
                        default="person",
                        type=str,
                        help='first object index')
    parser.add_argument('--second',
                        default="bus",
                        type=str,
                        help='second object index')
    parser.add_argument('--third',
                        default="clock",
                        type=str,
                        help='third object index')
    parser.add_argument('--pretrained',
                        default='/set/your/model/path',
                        type=str,
                        metavar='PATH')
    parser.add_argument('--debug',
                        help='Check model accuracy',
                        action='store_true')
    parser.add_argument('--ratio',
                        default=0.5,
                        type=float,
                        help='target ratio for batchnorm layers')
    parser.add_argument('--replace',
                        help='replace bn layer ',
                        action='store_true')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.log_dir) and not args.resume:
        print('Path {} exists and resume is disabled; aborting.'.format(args.log_dir))
        return
    if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)

    #save all parameters for training
    with open(os.path.join(args.log_dir, "arguments.log"), "a") as f:
        f.write(str(args) + '\n')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Image preprocessing
    train_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(), normalize
    ])
    # Data samplers.
    train_data = CocoObject(ann_dir=args.ann_dir,
                            image_dir=args.image_dir,
                            split='train',
                            transform=train_transform)
    val_data = CocoObject(ann_dir=args.ann_dir,
                          image_dir=args.image_dir,
                          split='val',
                          transform=val_transform)
    first_data = CocoObject(ann_dir=args.ann_dir,
                            image_dir=args.image_dir,
                            split='train',
                            transform=train_transform,
                            filter=args.first)
    second_data = CocoObject(ann_dir=args.ann_dir,
                             image_dir=args.image_dir,
                             split='train',
                             transform=train_transform,
                             filter=args.second)

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=True)
    # Build the models
    model = MultilabelObject(args, 80).cuda()

    criterion = nn.BCEWithLogitsLoss(weight=torch.FloatTensor(
        train_data.getObjectWeights()),
                                     reduction='mean').cuda()

    def trainable_params():
        for param in model.parameters():
            if param.requires_grad:
                yield param

    optimizer = torch.optim.Adam(trainable_params(),
                                 args.learning_rate,
                                 weight_decay=1e-5)

    best_performance = 0
    if os.path.isfile(args.pretrained):
        train_F = open(os.path.join(args.log_dir, 'train.csv'), 'w')
        val_F = open(os.path.join(args.log_dir, 'val.csv'), 'w')
        score_F = open(os.path.join(args.log_dir, 'score.csv'), 'w')
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        args.start_epoch = checkpoint['epoch']
        best_performance = checkpoint['best_performance']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

        epoch = args.start_epoch  # no training loop here; evaluate once at the loaded epoch
        current_performance = get_confusion(args, epoch, model, criterion,
                                            val_loader, optimizer, val_F,
                                            score_F, val_data)
        confusion_matrix = global_epoch_confusion[-1]["confusion"]
        first_second = compute_confusion(confusion_matrix, args.first,
                                         args.second)
        first_third = compute_confusion(confusion_matrix, args.first,
                                        args.third)
        print(
            str((args.first, args.second, args.third)) + " triplet: " + str(
                compute_bias(confusion_matrix, args.first, args.second,
                             args.third)))
        print(str((args.first, args.second)) + ": " + str(first_second))
        print(str((args.first, args.third)) + ": " + str(first_third))
    else:
        print("=> no checkpoint found at '{}'".format(args.pretrained))
        return

    train_F.close()
    val_F.close()
    score_F.close()
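
Example #3's --replace/--ratio flags suggest some form of batch-norm surgery on the loaded model; that project code is not shown here. A generic sketch of walking a model and replacing every BatchNorm2d in place (the helper name and constructor are hypothetical, not the project's implementation):

import torch.nn as nn

def replace_bn(module, make_new):
    # recursively swap each nn.BatchNorm2d child for make_new(old_bn)
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, make_new(child))
        else:
            replace_bn(child, make_new)

# e.g. swap in fresh BN layers of the same width:
# replace_bn(model, lambda bn: nn.BatchNorm2d(bn.num_features))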
Code Example #4
def get_yhats_train(confidence=0.5):
    crop_size = 224
    image_size = 256
    batch_size = 16
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(), normalize
    ])

    # Data samplers.
    train_data = CocoObject(ann_dir=ann_dir,
                            image_dir=image_dir,
                            split='train',
                            transform=val_transform)
    image_ids = train_data.new_image_ids
    image_path_map = train_data.image_path_map
    #80 objects
    id2object = train_data.id2object
    id2labels = train_data.id2labels
    # Data loaders / batch assemblers.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=4,
                                               pin_memory=True)
    model = MultilabelObject(None, 81).cuda()

    log_dir = "./"
    checkpoint = torch.load(os.path.join(log_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    t = tqdm(train_loader, desc='Activation')
    count = 0
    yhats = []
    labels = []
    imagefiles = []
    for batch_idx, (images, objects, image_ids) in enumerate(t):

        images = Variable(images).cuda()
        objects = Variable(objects).cuda()
        object_preds = model(images)
        m = nn.Sigmoid()
        object_preds_r = m(object_preds)
        count = count + len(image_ids)
        for i in range(len(image_ids)):
            image_file_name = image_path_map[int(image_ids[i])]
            yhat = []
            label = id2labels[int(image_ids[i])]

            for j in range(len(object_preds[i])):
                a = object_preds_r[i][j].cpu().data.numpy()
                if a > confidence:  # a is a 0-d array; compare the scalar directly
                    yhat.append(id2object[j])
            yhats.append(yhat)
            labels.append(label)
            imagefiles.append(image_file_name)
        if count % 1000 == 0:
            print("count: " + str(count))

    with open('globalyhats_train.pickle', 'wb') as handle:
        pickle.dump(yhats, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('globallabels_train.pickle', 'wb') as handle:
        pickle.dump(labels, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('imagefiles_train.pickle', 'wb') as handle:
        pickle.dump(imagefiles, handle, protocol=pickle.HIGHEST_PROTOCOL)
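
get_yhats_train persists its three parallel lists with pickle; reading them back is symmetric. A minimal sketch:

import pickle

with open('globalyhats_train.pickle', 'rb') as handle:
    yhats = pickle.load(handle)
with open('globallabels_train.pickle', 'rb') as handle:
    labels = pickle.load(handle)
with open('imagefiles_train.pickle', 'rb') as handle:
    imagefiles = pickle.load(handle)
assert len(yhats) == len(labels) == len(imagefiles)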
Code Example #5
def get_coverage_test():
    global globalcoverage
    ann_dir = '/local/yuchi/dataset/coco/annotations'
    image_dir = '/local/yuchi/dataset/coco/'
    crop_size = 224
    image_size = 256
    batch_size = 16
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(), normalize
    ])

    # Data samplers.
    train_data = CocoObject(ann_dir=ann_dir,
                            image_dir=image_dir,
                            split='test',
                            transform=val_transform)
    image_ids = train_data.new_image_ids
    image_path_map = train_data.image_path_map
    #80 objects
    id2object = train_data.id2object
    id2labels = train_data.id2labels
    # Data loaders / batch assemblers.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=4,
                                               pin_memory=True)
    model = MultilabelObject(None, 81).cuda()
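    # project helper: presumably registers a forward hook on every conv layer
    # so each batch's channel activations are recorded into globalcoverage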
    hook_all_conv_layer(model, get_channel_coverage_group_exp)
    log_dir = "./"
    log_dir1 = "/home/yuchi/work/coco/backup"
    checkpoint = torch.load(os.path.join(log_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()
    t = tqdm(train_loader, desc='Activation')
    count = 0
    for batch_idx, (images, objects, image_ids) in enumerate(t):

        images = Variable(images).cuda()
        objects = Variable(objects).cuda()

        for i in range(len(image_ids)):
            globalcoverage.append({})
            image_file_name = image_path_map[int(image_ids[i])]
            yhat = []
            '''
            for j in xrange(len(object_preds[i])):
                a = object_preds_r[i][j].cpu().data.numpy()
                if a[0] > 0.5:
                    yhat.append(id2object[j])
            '''
            globalcoverage[-1]["file"] = image_file_name
            globalcoverage[-1]["yhat"] = yhat
            globalcoverage[-1]["dataset"] = "test"
            globalcoverage[-1]["jlabel"] = id2labels[int(image_ids[i])]

        object_preds = model(images)
        m = nn.Sigmoid()
        object_preds_r = m(object_preds)
        count = count + len(image_ids)

        if count % 1000 == 0:
            print("count: " + str(count))
Code Example #6
from pycocotools.coco import COCO
# from PIL import Image
from data_loader import CocoObject
import numpy as np
import os
import cv2
from tqdm import tqdm

if __name__ == '__main__':
    ann_dir = '/home/ytianas/EMSE_COCO/cocodataset/annotations'
    image_dir = '/home/ytianas/EMSE_COCO/cocodataset/'
    test_data = CocoObject(ann_dir=ann_dir,
                           image_dir=image_dir,
                           split='val',
                           transform=None)
    image_ids = test_data.image_ids
    image_path_map = test_data.image_path_map
    # 80 objects
    id2object = test_data.id2object
    id2labels = test_data.id2labels

    # print(id2labels)
    # print(id2object)
    # exit(-1)

    ann_cat_name = test_data.ann_cat_name
    ann_cat_id = test_data.ann_cat_id
    bboxes = test_data.bbox
    masks = test_data.mask

    fill_values = [0, 127, 255]
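
The snippet ends before fill_values is used, but the setup (bounding boxes, masks, three grey levels) points at occlusion-style image perturbation. A purely illustrative sketch of blanking a bbox region with each fill value (not the project's code):

import numpy as np

def occlude_bbox(image, bbox, fill):
    # bbox in COCO convention: [x, y, width, height]
    x, y, w, h = [int(round(v)) for v in bbox]
    out = image.copy()
    out[y:y + h, x:x + w] = fill
    return out

# img = cv2.imread(path)
# variants = [occlude_bbox(img, bbox, v) for v in fill_values]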
Code Example #7
File: train.py Project: yzx-fish/DeepInspect
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir',
                        type=str,
                        default='',
                        help='path for saving trained models and log info')
    parser.add_argument('--ann_dir',
                        type=str,
                        default='/home/tyc/Downloads/cocodataset/annotations',
                        help='path for annotation json file')
    parser.add_argument('--image_dir',
                        default='/home/tyc/Downloads/cocodataset')

    parser.add_argument('--resume',
                        default=0,
                        type=int,
                        help='whether to resume from log_dir if existent')
    parser.add_argument('--finetune', default=0, type=int)
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--start_epoch', type=int, default=1)
    parser.add_argument('--batch_size', type=int,
                        default=64)  # batch size should be smaller if text features are used
    parser.add_argument('--crop_size', type=int, default=224)
    parser.add_argument('--image_size', type=int, default=256)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.log_dir) and not args.resume:
        print('Path {} exists and resume is disabled; aborting.'.format(args.log_dir))
        return
    if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)

    #save all parameters for training
    with open(os.path.join(args.log_dir, "arguments.log"), "a") as f:
        f.write(str(args) + '\n')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Image preprocessing
    train_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(), normalize
    ])

    # Data samplers.
    train_data = CocoObject(ann_dir=args.ann_dir,
                            image_dir=args.image_dir,
                            split='train',
                            transform=train_transform)

    val_data = CocoObject(ann_dir=args.ann_dir,
                          image_dir=args.image_dir,
                          split='val',
                          transform=val_transform)

    # Data loaders / batch assemblers.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=6,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True)

    # Build the models
    model = MultilabelObject(args, 80).cuda()
    criterion = nn.BCEWithLogitsLoss(weight=torch.FloatTensor(
        train_data.getObjectWeights()),
                                     reduction='mean').cuda()

    def trainable_params():
        for param in model.parameters():
            if param.requires_grad:
                yield param

    optimizer = torch.optim.Adam(trainable_params(),
                                 args.learning_rate,
                                 weight_decay=1e-5)

    best_performance = 0
    if args.resume:
        train_F = open(os.path.join(args.log_dir, 'train.csv'), 'a')
        val_F = open(os.path.join(args.log_dir, 'val.csv'), 'a')
        score_F = open(os.path.join(args.log_dir, 'score.csv'), 'a')
        if os.path.isfile(os.path.join(args.log_dir, 'checkpoint.pth.tar')):
            print("=> loading checkpoint '{}'".format(args.log_dir))
            checkpoint = torch.load(
                os.path.join(args.log_dir, 'checkpoint.pth.tar'))
            args.start_epoch = checkpoint['epoch']
            best_performance = checkpoint['best_performance']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))
    else:
        train_F = open(os.path.join(args.log_dir, 'train.csv'), 'w')
        val_F = open(os.path.join(args.log_dir, 'val.csv'), 'w')
        score_F = open(os.path.join(args.log_dir, 'score.csv'), 'w')

    for epoch in range(args.start_epoch, args.num_epochs + 1):
        train(args, epoch, model, criterion, train_loader, optimizer, train_F,
              score_F)
        current_performance = test(args, epoch, model, criterion, val_loader,
                                   optimizer, val_F, score_F)
        is_best = current_performance > best_performance
        best_performance = max(current_performance, best_performance)
        model_state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_performance': best_performance
        }
        save_checkpoint(args, model_state, is_best,
                        os.path.join(args.log_dir, 'checkpoint.pth.tar'))
        os.system('python plot.py {} &'.format(args.log_dir))

    train_F.close()
    val_F.close()
    score_F.close()
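
save_checkpoint is defined elsewhere in the project. Given that Examples #4 and #5 load 'model_best.pth.tar', it presumably follows the idiomatic PyTorch pattern of torch.save plus a copy of the best weights; a sketch under that assumption, not the project's exact code:

import os
import shutil
import torch

def save_checkpoint(args, state, is_best, filename):
    # always write the latest state, and keep a separate copy of the best one
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename,
                        os.path.join(args.log_dir, 'model_best.pth.tar'))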