def main():

    global args, best_acc1
    args = parser.parse_args()

    # Check if CUDA is enabled
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Load data
    root = args.datasetPath

    print('Prepare files')
    files = [
        f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))
    ]

    idx = np.random.permutation(len(files))
    idx = idx.tolist()

    valid_ids = [files[i] for i in idx[0:10000]]
    test_ids = [files[i] for i in idx[10000:20000]]
    train_ids = [files[i] for i in idx[20000:]]

    data_train = datasets.Qm9(root, train_ids)
    data_valid = datasets.Qm9(root, valid_ids)
    data_test = datasets.Qm9(root, test_ids)

    # Define model and optimizer
    print('Define model')
    # Select one graph
    g_tuple, l = data_train[0]
    g, h_t, e = g_tuple
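    # h_t holds the node features and e the edge-feature dict of this sample
    # graph; their sizes are used below to set the model's input dimensions.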

    print('\tStatistics')
    # stat_dict = datasets.utils.get_graph_stats(data_valid, ['degrees', 'target_mean', 'target_std', 'edge_labels'])

    stat_dict = {}

    stat_dict['degrees'] = [1, 2, 3, 4]
    stat_dict['target_mean'] = np.array([
        2.71802732e+00, 7.51685080e+01, -2.40259300e-01, 1.09503300e-02,
        2.51209430e-01, 1.18997445e+03, 1.48493130e-01, -4.11609491e+02,
        -4.11601022e+02, -4.11600078e+02, -4.11642909e+02, 3.15894998e+01
    ])
    stat_dict['target_std'] = np.array([
        1.58422291e+00, 8.29443552e+00, 2.23854977e-02, 4.71030547e-02,
        4.77156393e-02, 2.80754665e+02, 3.37238236e-02, 3.97717205e+01,
        3.97715029e+01, 3.97715029e+01, 3.97722334e+01, 4.09458852e+00
    ])
    stat_dict['edge_labels'] = [1, 2, 3, 4]
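    # NOTE: the statistics above are hard-coded, precomputed QM9 values; the
    # commented-out get_graph_stats call above shows how they could be
    # recomputed from the data.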

    data_train.set_target_transform(lambda x: datasets.utils.normalize_data(
        x, stat_dict['target_mean'], stat_dict['target_std']))
    data_valid.set_target_transform(lambda x: datasets.utils.normalize_data(
        x, stat_dict['target_mean'], stat_dict['target_std']))
    data_test.set_target_transform(lambda x: datasets.utils.normalize_data(
        x, stat_dict['target_mean'], stat_dict['target_std']))

    # Data Loader
    train_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        data_test,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)

    print('\tCreate model')
    model = MpnnDuvenaud(stat_dict['degrees'],
                         [len(h_t[0]), len(list(e.values())[0])], [5, 15, 15],
                         30,
                         len(l),
                         type='regression')

    print('Optimizer')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    criterion = nn.MSELoss()

    evaluation = lambda output, target: torch.mean(
        torch.abs(output - target) / torch.abs(target))
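    # The evaluation metric is the mean relative deviation
    # |output - target| / |target|, averaged over the batch.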

    print('Logger')
    logger = Logger(args.logPath)

    lr_step = (args.lr - args.lr * args.lr_decay) / (
        args.epochs * args.schedule[1] - args.epochs * args.schedule[0])
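    # Linear LR decay applied between epochs*schedule[0] and epochs*schedule[1].
    # Worked example (hypothetical values): lr=1e-3, lr_decay=0.6, epochs=100,
    # schedule=[0.1, 0.9] gives lr_step = (1e-3 - 6e-4) / (90 - 10) = 5e-6 per epoch.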

    # get the best checkpoint if available without training
    if args.resume:
        checkpoint_dir = args.resume
        best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if os.path.isfile(best_model_file):
            print("=> loading best model '{}'".format(best_model_file))
            checkpoint = torch.load(best_model_file)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded best model '{}' (epoch {})".format(
                best_model_file, checkpoint['epoch']))
        else:
            print("=> no best model found at '{}'".format(best_model_file))

    print('Check cuda')
    if args.cuda:
        print('\t* Cuda')
        model = model.cuda()
        criterion = criterion.cuda()

    # Epoch for loop
    for epoch in range(0, args.epochs):

        if args.epochs * args.schedule[0] < epoch < args.epochs * args.schedule[1]:
            args.lr -= lr_step
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, evaluation,
              logger)

        # evaluate on validation set
        acc1 = validate(valid_loader, model, criterion, evaluation, logger)

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
            is_best=is_best,
            directory=args.resume)

        # Logger step
        logger.log_value('learning_rate', args.lr).step()

    # get the best checkpoint and test it with test set
    if args.resume:
        checkpoint_dir = args.resume
        best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if os.path.isfile(best_model_file):
            print("=> loading best model '{}'".format(best_model_file))
            checkpoint = torch.load(best_model_file)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded best model '{}' (epoch {})".format(
                best_model_file, checkpoint['epoch']))
        else:
            print("=> no best model found at '{}'".format(best_model_file))

    # For testing
    validate(test_loader, model, criterion, evaluation)
Example #2
def main():

    global args, best_acc1
    args = parser.parse_args()

    # Check if CUDA is enabled
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Load data
    root = args.datasetPath
    subset = args.subSet

    print('Prepare files')

    train_classes, train_ids = read_cxl(os.path.join(root, subset,
                                                     'train.cxl'))
    test_classes, test_ids = read_cxl(os.path.join(root, subset, 'test.cxl'))
    valid_classes, valid_ids = read_cxl(
        os.path.join(root, subset, 'validation.cxl'))

    class_list = list(set(train_classes + test_classes))
    num_classes = len(class_list)
    data_train = datasets.LETTER(root, subset, train_ids, train_classes,
                                 class_list)
    data_valid = datasets.LETTER(root, subset, valid_ids, valid_classes,
                                 class_list)
    data_test = datasets.LETTER(root, subset, test_ids, test_classes,
                                class_list)

    # Define model and optimizer
    print('Define model')
    # Select one graph
    g_tuple, l = data_train[0]
    g, h_t, e = g_tuple

    #TODO: Need attention
    print('\tStatistics')
    stat_dict = {}
    # stat_dict = datasets.utils.get_graph_stats(data_train, ['edge_labels'])
    stat_dict['edge_labels'] = [1]
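    # Edge-label statistics are hard-coded here; the commented-out
    # get_graph_stats call above would recompute them from data_train.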

    # Data Loader
    train_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        data_test,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)

    print('\tCreate model')
    model = MpnnIntNet([len(h_t[0]), len(list(e.values())[0])], [5, 15, 15],
                       [10, 20, 20],
                       num_classes,
                       type='classification')

    print('Optimizer')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    criterion = nn.NLLLoss()

    evaluation = utils.accuracy

    print('Logger')
    logger = Logger(args.logPath)

    lr_step = (args.lr - args.lr * args.lr_decay) / (
        args.epochs * args.schedule[1] - args.epochs * args.schedule[0])

    # get the best checkpoint if available without training
    if args.resume:
        checkpoint_dir = args.resume
        best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if os.path.isfile(best_model_file):
            print("=> loading best model '{}'".format(best_model_file))
            checkpoint = torch.load(best_model_file)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded best model '{}' (epoch {}; accuracy {})".format(
                best_model_file, checkpoint['epoch'], best_acc1))
        else:
            print("=> no best model found at '{}'".format(best_model_file))

    print('Check cuda')
    if args.cuda:
        print('\t* Cuda')
        model = model.cuda()
        criterion = criterion.cuda()

    # Epoch for loop
    for epoch in range(0, args.epochs):

        if args.epochs * args.schedule[0] < epoch < args.epochs * args.schedule[1]:
            args.lr -= lr_step
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, evaluation,
              logger)

        # evaluate on validation set
        acc1 = validate(valid_loader, model, criterion, evaluation, logger)

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
            is_best=is_best,
            directory=args.resume)

        # Logger step
        logger.log_value('learning_rate', args.lr).step()

    # get the best checkpoint and test it with test set
    if args.resume:
        checkpoint_dir = args.resume
        best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if os.path.isfile(best_model_file):
            print("=> loading best model '{}'".format(best_model_file))
            checkpoint = torch.load(best_model_file)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded best model '{}' (epoch {}; accuracy {})".format(
                best_model_file, checkpoint['epoch'], best_acc1))
        else:
            print("=> no best model found at '{}'".format(best_model_file))

    # For testing
    validate(test_loader, model, criterion, evaluation)
Example #3
def main(opt):
    if not os.path.exists(opt.resume):
        os.makedirs(opt.resume)
    if not os.path.exists(opt.logroot):
        os.makedirs(opt.logroot)

    log_dir_name = str(opt.manualSeed) + '/'
    log_path = os.path.join(opt.logroot, log_dir_name)
    opt.resume = os.path.join(opt.resume, log_dir_name)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    #log_file_name = log_path + 'ucf_log_st.txt'
    #log_file_name = opt.logroot + 'ucf_log_st_'+str(opt.manualSeed)+'.txt'

    log_file_name = opt.logroot + 'something_log_v4.1_' + str(
        opt.manualSeed) + '.txt'

    with open(log_file_name, 'a+') as file:
        file.write('manualSeed is %d \n' % opt.manualSeed)
    paths = config.Paths()

    train_datalist = "/home/mcislab/zhaojw/AAAI/sth_train_list.txt"
    val_datalist = "/home/mcislab/zhaojw/AAAI/sth_val_list.txt"
    test_datalist = "/home/mcislab/zhaojw/AAAI/sth_test_list.txt"
    #test_datalist = '/home/mcislab/wangruiqi/IJCV2019/data/newsomething-check.txt'
    #opt.resume = os.path.join(opt.resume,log_dir_name)

    train_dataset = dataset(train_datalist, paths.sthv2_final, opt)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.workers,
                                  drop_last=False)

    val_dataset = dataset(val_datalist, paths.sthv2_final, opt)

    val_dataloader = DataLoader(val_dataset,
                                batch_size=opt.batch_size,
                                shuffle=False,
                                num_workers=opt.workers,
                                drop_last=False)

    test_dataset = dataset(test_datalist, paths.sthv2_final, opt)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=opt.batch_size,
                                 shuffle=False,
                                 num_workers=opt.workers,
                                 drop_last=False)

    model = sthv2_model.Model(opt)
    '''
    if opt.show:
        show(model)
        exit()
    '''

    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=100,
                                                gamma=0.9)
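    # StepLR multiplies the learning rate by gamma every step_size epochs,
    # i.e. lr <- lr * 0.9 every 100 epochs here (scheduler.step() is called
    # once per epoch in the training loop below).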
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = nn.NLLLoss()
    if opt.cuda:
        model.cuda()
        #criterion.cuda(opt.device_id)
        criterion1.cuda()
        criterion2.cuda()
    '''
    if opt.epoch != 0:
        if os.path.exists('./models/hmdb_split1/'+checkpoint_model_name):
            model.load_state_dict(torch.load('./models/hmdb_split1/' + checkpoint_model_name))
        else:
            print('model not found')
            exit()
    '''
    #Lin commented on Sept. 2nd
    #model.double()

    writer = SummaryWriter(log_dir=os.path.join(log_path, 'runs/'))
    # For training
    sum_test_acc = []
    best_acc = 0.
    #epoch_errors = list()
    avg_epoch_error = np.inf
    best_epoch_error = np.inf
    '''
    #haha, output Acc for each class
    test_load_dir = opt.resume
    #test_load_dir = '/home/mcislab/linhanxi/IJCV19_Experiments/sth_scale/something_scale5_M/ckpnothresh/ours'
    model.load_state_dict(torch.load(os.path.join(test_load_dir, 'model_best.pth'))['state_dict'])
    if opt.featdir:
        model.feat_mode()
    test_acc, output = test(0,test_dataloader, model, criterion1, criterion2, opt, writer, test_load_dir, is_test=True)
    exit()
    '''
    print("Test once to get a baseline.")
    loaded_checkpoint = utils.load_best_checkpoint(opt, model, optimizer)
    if loaded_checkpoint:
        opt, model, optimizer = loaded_checkpoint
        test_acc, output = test(51,
                                test_dataloader,
                                model,
                                criterion1,
                                criterion2,
                                opt,
                                writer,
                                log_file_name,
                                is_test=True)
        tmp_test_acc = np.mean(test_acc)
        if tmp_test_acc > best_acc:
            best_acc = tmp_test_acc

    print("Start to train.....")
    for epoch_i in range(opt.epoch, opt.niter):
        scheduler.step()

        train(epoch_i, train_dataloader, model, criterion1, criterion2,
              optimizer, opt, writer, log_file_name)
        #val_acc, val_out, val_error =test(valid_loader, model, criterion1,criterion2, opt, log_file_name, is_test=False)
        # Lin changed according to 'sth_pre_abl1' on Sept. 3rd
        test_acc, output = val(epoch_i,
                               val_dataloader,
                               model,
                               criterion1,
                               criterion2,
                               opt,
                               writer,
                               log_file_name,
                               is_test=True)
        #test_acc,_ = test(test_dataloader, model, criterion1, criterion2, opt, log_file_name, is_test=True)

        tmp_test_acc = np.mean(test_acc)
        sum_test_acc.append(test_acc)

        if tmp_test_acc > best_acc:
            is_best = True
            best_acc = tmp_test_acc
        else:
            is_best = False

        utils.save_checkpoint(
            {
                'epoch': epoch_i,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            is_best=is_best,
            directory=opt.resume)
        print("A training epoch finished!")

    #epoch_i =33

    # For testing
    print("Training finished.Start to test.")
    loaded_checkpoint = utils.load_best_checkpoint(opt, model, optimizer)
    if loaded_checkpoint:
        opt, model, optimizer = loaded_checkpoint
    # Lin changed according to 'sth_pre_abl1' on Sept. 3rd
    test_acc, output = test(epoch_i,
                            test_dataloader,
                            model,
                            criterion1,
                            criterion2,
                            opt,
                            writer,
                            log_file_name,
                            is_test=True)
    #test_acc,output = test(test_dataloader, model, criterion1,criterion2,  opt, log_file_name, is_test=True)
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("ratio=0.1, test Accuracy:   %.2f " % (100. * test_acc[0][0]))
    print("ratio=0.2, test Accuracy:   %.2f " % (100. * test_acc[0][1]))
    print("ratio=0.3, test Accuracy:   %.2f " % (100. * test_acc[0][2]))
    print("ratio=0.4, test Accuracy:   %.2f " % (100. * test_acc[0][3]))
    print("ratio=0.5, test Accuracy:   %.2f " % (100. * test_acc[0][4]))
    print("ratio=0.6, test Accuracy:   %.2f " % (100. * test_acc[0][5]))
    print("ratio=0.7, test Accuracy:   %.2f " % (100. * test_acc[0][6]))
    print("ratio=0.8, test Accuracy:   %.2f " % (100. * test_acc[0][7]))
    print("ratio=0.9, test Accuracy:   %.2f " % (100. * test_acc[0][8]))
    print("ratio=1.0, test Accuracy:   %.2f " % (100. * test_acc[0][9]))
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
Example #4
def main(opt):
    if not os.path.exists(opt.resume):
        os.makedirs(opt.resume)
    if not os.path.exists(opt.logroot):
        os.makedirs(opt.logroot)
    #log_dir_name = 'split'+opt.split + '/'
    log_dir_name = 'split' + opt.split + '/' + str(opt.manualSeed) + '/'
    opt.resume = os.path.join(opt.resume, log_dir_name)
    log_path = os.path.join(opt.logroot, log_dir_name)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    #log_file_name = log_path + 'ucf_log_st.txt'
    log_file_name = log_path + 'ucf_log_v5.0_st_' + str(opt.manualSeed) + '.txt'

    with open(log_file_name, 'a+') as file:
        file.write('manualSeed is %d \n' % opt.manualSeed)
        file.write('state_dim is %d \n' % opt.state_dim)
        file.write('num_bottleneck is %d \n' % opt.num_bottleneck)
    paths = config.Paths()

    train_datalist = '/home/mcislab/wangruiqi/IJCV2019/data/ucf101Vid_train_lin_split'+opt.split+'.txt'
    test_datalist = '/home/mcislab/wangruiqi/IJCV2019/data/ucf101Vid_val_lin_split'+opt.split+'.txt'
    #train_datalist = '/home/mcislab/wangruiqi/IJCV2019/data/test.txt'
    #test_datalist = '/home/mcislab/wangruiqi/IJCV2019/data/test.txt'
    #train_dataset = dataset(train_datalist, paths.detect_root_ucf_mmdet, paths.img_root_ucf,paths.rgb_res18_ucf,paths.rgb_res18_ucf, opt)
    train_dataset = dataset(train_datalist, paths.bninception_ucf, opt)  ####zhao changed ###paths.resnet50_ucf_rgbflow_same
    train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, drop_last=False)
    #test_dataset = dataset(test_datalist, paths.detect_root_ucf_mmdet, paths.img_root_ucf,paths.rgb_res18_ucf,paths.rgb_res18_ucf, opt)
    test_dataset = dataset(test_datalist, paths.bninception_ucf, opt)  ###paths.resnet50_ucf_rgbflow_same
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers, drop_last=False)

    model = ucf_main_without_systhesizing_model.Model(opt)
    #optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = nn.NLLLoss()

    if opt.cuda:
        model.cuda()
        #criterion.cuda(opt.device_id)
        criterion1.cuda()
        criterion2.cuda()

    '''
    if opt.epoch != 0:
        if os.path.exists('./models/hmdb_split1/'+checkpoint_model_name):
            model.load_state_dict(torch.load('./models/hmdb_split1/' + checkpoint_model_name))
        else:
            print('model not found')
            exit()
    '''
    #Lin commented on Sept. 2nd
    #model.double()


    writer = SummaryWriter(log_dir=log_path+'runs/')
    # For training
    sum_test_acc = []
    best_acc = 0.
    epoch_errors = list()
    avg_epoch_error = np.inf
    best_epoch_error = np.inf
    '''
    #haha, output Acc for each class
    model.load_state_dict(torch.load('/home/mcislab/linhanxi/ucf101_NewFeat_RGBtuned/ckpnothresh/ours/model_best.pth')['state_dict'])
    test_acc, output = test(0,test_dataloader, model, criterion1, criterion2, opt, writer, log_file_name, is_test=True)
    exit()
    
    #load last experiment best model
    print("load last experiment best model")
    model.load_state_dict(torch.load('/home/mcislab/zhaojw/AAAI/prediction2020/models/ucf101_res50/split1/1050/model_best.pth')['state_dict'])
    test_acc, output = test(0,test_dataloader, model, criterion1, criterion2, opt, writer, log_file_name, is_test=True)
    '''
    
    print ("Test once for a baseline.")
    loaded_checkpoint =utils.load_best_checkpoint(opt, model, optimizer)
    if loaded_checkpoint:
        #opt, model, optimizer = loaded_checkpoint
        opt, model, __ = loaded_checkpoint
        test_acc, output = test(1,test_dataloader, model, criterion1, criterion2, opt, writer, log_file_name, is_test=True)
        tmp_test_acc = np.mean(test_acc)
        if tmp_test_acc > best_acc:
         
            best_acc = tmp_test_acc
        

    print ("Start to train.....")
    #model.load_state_dict(torch.load('/home/mcislab/linhanxi/ucf101_flowOnly/ckpnothresh/ours/checkpoint.pth')['state_dict'])
    for epoch_i in range(opt.epoch, opt.niter):
        scheduler.step()

        train(epoch_i, train_dataloader, model, criterion1, criterion2, optimizer, opt, writer, log_file_name)
        #val_acc, val_out, val_error =test(valid_loader, model, criterion1,criterion2, opt, log_file_name, is_test=False)
        # Lin changed according to 'sth_pre_abl1' on Sept. 3rd
        test_acc, output = test(epoch_i, test_dataloader, model, criterion1, criterion2, opt, writer, log_file_name, is_test=True)
        #test_acc,_ = test(test_dataloader, model, criterion1, criterion2, opt, log_file_name, is_test=True)

        tmp_test_acc = np.mean(test_acc)
        sum_test_acc.append(test_acc)

        if tmp_test_acc > best_acc:
            is_best = True
            best_acc = tmp_test_acc
        else:
            is_best = False

        utils.save_checkpoint({'epoch': epoch_i, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                              is_best=is_best, directory=opt.resume)
        print("A training epoch finished!")
       
    # For testing
    print("Training finished. Start to test.")
    loaded_checkpoint = utils.load_best_checkpoint(opt, model, optimizer)
    if loaded_checkpoint:
        opt, model, __ = loaded_checkpoint
    # Lin changed according to 'sth_pre_abl1' on Sept. 3rd
    test_acc, output = test(epoch_i, test_dataloader, model, criterion1, criterion2, opt, writer, log_file_name, is_test=True)
    #test_acc,output = test(test_dataloader, model, criterion1,criterion2,  opt, log_file_name, is_test=True)
    print ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print ("ratio=0.1, test Accuracy:   %.2f " % (100. * test_acc[0][0]))
    print ("ratio=0.2, test Accuracy:   %.2f " % (100. * test_acc[0][1]))
    print ("ratio=0.3, test Accuracy:   %.2f " % (100. * test_acc[0][2]))
    print ("ratio=0.4, test Accuracy:   %.2f " % (100. * test_acc[0][3]))
    print ("ratio=0.5, test Accuracy:   %.2f " % (100. * test_acc[0][4]))
    print ("ratio=0.6, test Accuracy:   %.2f " % (100. * test_acc[0][5]))
    print ("ratio=0.7, test Accuracy:   %.2f " % (100. * test_acc[0][6]))
    print ("ratio=0.8, test Accuracy:   %.2f " % (100. * test_acc[0][7]))
    print ("ratio=0.9, test Accuracy:   %.2f " % (100. * test_acc[0][8]))
    print ("ratio=1.0, test Accuracy:   %.2f " % (100. * test_acc[0][9]))
    print ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
Example #5
def main():

    global args, best_er1
    args = parser.parse_args()

    # Check if CUDA is enabled
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Load data
    root = args.datasetPath

    print('Prepare files')
    files = [
        f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))
    ]

    idx = np.random.permutation(len(files))
    idx = idx.tolist()

    valid_ids = [files[i] for i in idx[0:10000]]
    test_ids = [files[i] for i in idx[10000:20000]]
    train_ids = [files[i] for i in idx[20000:]]

    data_train = datasets.Qm9(root,
                              train_ids,
                              edge_transform=utils.qm9_edges,
                              e_representation='raw_distance')
    data_valid = datasets.Qm9(root,
                              valid_ids,
                              edge_transform=utils.qm9_edges,
                              e_representation='raw_distance')
    data_test = datasets.Qm9(root,
                             test_ids,
                             edge_transform=utils.qm9_edges,
                             e_representation='raw_distance')

    # Define model and optimizer
    print('Define model')
    # Select one graph
    g_tuple, l = data_train[0]
    g, h_t, e = g_tuple

    print('\tStatistics')
    stat_dict = datasets.utils.get_graph_stats(data_valid,
                                               ['target_mean', 'target_std'])

    data_train.set_target_transform(lambda x: datasets.utils.normalize_data(
        x, stat_dict['target_mean'], stat_dict['target_std']))
    data_valid.set_target_transform(lambda x: datasets.utils.normalize_data(
        x, stat_dict['target_mean'], stat_dict['target_std']))
    data_test.set_target_transform(lambda x: datasets.utils.normalize_data(
        x, stat_dict['target_mean'], stat_dict['target_std']))

    # Data Loader
    train_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        data_test,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)

    print('\tCreate model')
    in_n = [len(h_t[0]), len(list(e.values())[0])]
    hidden_state_size = 73
    message_size = 73
    n_layers = 3
    l_target = len(l)
    type = 'regression'
    model = MPNN(in_n,
                 hidden_state_size,
                 message_size,
                 n_layers,
                 l_target,
                 type=type)
    del in_n, hidden_state_size, message_size, n_layers, l_target, type

    print('Optimizer')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    criterion = nn.MSELoss()

    def evaluation(output, target):
        return torch.mean(torch.abs(output - target) / torch.abs(target))
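    # Mean relative error of the (normalized) predictions: lower is better,
    # which is why best_er1 is tracked with min() below.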

    print('Logger')
    logger = Logger(args.logPath)

    lr_step = (args.lr - args.lr * args.lr_decay) / (
        args.epochs * args.schedule[1] - args.epochs * args.schedule[0])

    # get the best checkpoint if available without training
    if args.resume:
        checkpoint_dir = args.resume
        best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if os.path.isfile(best_model_file):
            print("=> loading best model '{}'".format(best_model_file))
            checkpoint = torch.load(best_model_file)
            args.start_epoch = checkpoint['epoch']
            best_er1 = checkpoint['best_er1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded best model '{}' (epoch {})".format(
                best_model_file, checkpoint['epoch']))
        else:
            print("=> no best model found at '{}'".format(best_model_file))

    print('Check cuda')
    if args.cuda:
        print('\t* Cuda')
        model = model.cuda()
        criterion = criterion.cuda()

    # Epoch for loop
    for epoch in range(0, args.epochs):

        if args.epochs * args.schedule[0] < epoch < args.epochs * args.schedule[1]:
            args.lr -= lr_step
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, evaluation,
              logger)

        # evaluate on validation set
        er1 = validate(valid_loader, model, criterion, evaluation, logger)

        is_best = er1 < best_er1
        best_er1 = min(er1, best_er1)
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_er1': best_er1,
                'optimizer': optimizer.state_dict(),
            },
            is_best=is_best,
            directory=args.resume)

        # Logger step
        logger.log_value('learning_rate', args.lr).step()

    # get the best checkpoint and test it with test set
    if args.resume:
        checkpoint_dir = args.resume
        best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if os.path.isfile(best_model_file):
            print("=> loading best model '{}'".format(best_model_file))
            checkpoint = torch.load(best_model_file)
            args.start_epoch = checkpoint['epoch']
            best_er1 = checkpoint['best_er1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.cuda:
                model.cuda()
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded best model '{}' (epoch {})".format(
                best_model_file, checkpoint['epoch']))
        else:
            print("=> no best model found at '{}'".format(best_model_file))

    # For testing
    validate(test_loader, model, criterion, evaluation)
Example #6
def main(opt):
    if not os.path.exists(opt.resume):
        os.makedirs(opt.resume)
    if not os.path.exists(opt.logroot):
        os.makedirs(opt.logroot)

  
    paths = cad120_config.Paths()
    #print (paths.tmp_root)
    subject_ids = pickle.load(open(os.path.join(paths.tmp_root, 'cad120_data_list.p'), 'rb'))

    data_path = os.path.join(paths.tmp_root, 'cad120_data_pred.p')
    #data_path =  '/media/mcislab/wrq/CAD120/pred-feature/cad120_data_pred.p'
    test_acc_final = np.zeros([4, opt.seq_size])
    sub_index = 0
    
    resume_root = os.path.join(opt.resume, str(opt.manualSeed))
    
    for sub, seqs in subject_ids.items():  # cross-validation for each subject
        #sub='Subject'+str(opt.subject)
        log_file_name = opt.logroot + 'cad120_log_sub' + sub + '.txt'
        with open(log_file_name, 'a+') as file:
            file.write('manualSeed is %d \n' % opt.manualSeed)
        opt.resume = resume_root + sub + '/'

        training_subject = pickle.load(open(os.path.join(paths.tmp_root, 'cad120_data_list.p'), 'rb'))  # if not reload it will delete both in subject_ids and training_sub
        testing_subject = dict()
        testing_subject[sub] = seqs
        
        del training_subject[sub]

        
        #print training_subject
        #print testing_subject

        training_set = CAD120(data_path, training_subject)

        testing_set = CAD120(data_path, testing_subject)

        #testing_set = CAD120(data_path, sequence_ids[-test_num:])


        train_loader = torch.utils.data.DataLoader(training_set, collate_fn=utils.collate_fn_cad, batch_size=opt.batch_size,
                                                   num_workers=opt.workers, shuffle=True, pin_memory=True)
        #valid_loader = torch.utils.data.DataLoader(valid_set, collate_fn=utils.collate_fn_cad,
        #                                           batch_size=opt.batch_size,
        #                                           num_workers=opt.workers, shuffle=False, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(testing_set, collate_fn=utils.collate_fn_cad, batch_size=opt.batch_size,
                                                  num_workers=opt.workers, shuffle=False, pin_memory=True)

        model = models.Model(opt) 
        optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9)
        criterion1 = nn.CrossEntropyLoss()
        criterion2 = nn.NLLLoss()
        if opt.cuda:
            model.cuda()
            #criterion.cuda(opt.device_id)
            criterion1.cuda()
            criterion2.cuda()


        loaded_checkpoint = utils.load_best_checkpoint(opt, model, optimizer)
        if loaded_checkpoint:
            opt, model, optimizer = loaded_checkpoint
        '''
        if opt.epoch != 0:
            if os.path.exists('./models/hmdb_split1/'+checkpoint_model_name):
                model.load_state_dict(torch.load('./models/hmdb_split1/' + checkpoint_model_name))
            else:
                print('model not found')
                exit()
        '''
        #model.double()


        writer = SummaryWriter(log_dir=opt.logroot+'runs/'+sub+'/')
        # For training
        sum_test_acc = []
        best_acc = 0.
        epoch_errors = list()
        avg_epoch_error = np.inf
        best_epoch_error = np.inf
        
        for epoch_i in range(opt.epoch, opt.niter):
            scheduler.step()
            train(epoch_i, train_loader, model, criterion1, criterion2, optimizer, opt, writer, log_file_name)
            #val_acc, val_out, val_error =test(valid_loader, model, criterion1,criterion2, opt, log_file_name, is_test=False)
            test_acc, output = test(epoch_i, test_loader, model, criterion1, criterion2, opt, writer, log_file_name)

            tmp_test_acc = np.mean(test_acc)
            sum_test_acc.append(test_acc)

            if tmp_test_acc > best_acc:
                is_best = True
                best_acc = tmp_test_acc
            else:
                is_best = False

            utils.save_checkpoint({'epoch': epoch_i + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                                  is_best=is_best, directory=opt.resume)
        
        # For testing
        loaded_checkpoint = utils.load_best_checkpoint(opt, model, optimizer)
        if loaded_checkpoint:
            opt, model, optimizer = loaded_checkpoint
        test_acc, output = test(epoch_i, test_loader, model, criterion1, criterion2, opt, writer, log_file_name)
        print ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print ("ratio=0.1, test Accuracy:   %.2f " % (100. * test_acc[0][0]))
        print ("ratio=0.2, test Accuracy:   %.2f " % (100. * test_acc[0][1]))
        print ("ratio=0.3, test Accuracy:   %.2f " % (100. * test_acc[0][2]))
        print ("ratio=0.4, test Accuracy:   %.2f " % (100. * test_acc[0][3]))
        print ("ratio=0.5, test Accuracy:   %.2f " % (100. * test_acc[0][4]))
        print ("ratio=0.6, test Accuracy:   %.2f " % (100. * test_acc[0][5]))
        print ("ratio=0.7, test Accuracy:   %.2f " % (100. * test_acc[0][6]))
        print ("ratio=0.8, test Accuracy:   %.2f " % (100. * test_acc[0][7]))
        print ("ratio=0.9, test Accuracy:   %.2f " % (100. * test_acc[0][8]))
        print ("ratio=1.0, test Accuracy:   %.2f " % (100. * test_acc[0][9]))
        print ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        sum_test_acc = np.array(sum_test_acc)
        sum_test_acc = sum_test_acc.reshape(opt.niter, opt.seq_size)
        scio.savemat(opt.logroot + sub + '_result.mat', {'test_acc': sum_test_acc})
        scio.savemat(opt.logroot + sub + '_output.mat', {'test_out': output})

        test_acc_final[sub_index, :] = test_acc
        
        with open(log_file_name, 'a+') as file:
            file.write("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
            file.write("ratio=0.1, test Accuracy: %.2f \n" % (100. * test_acc[0][0]))
            file.write("ratio=0.2, test Accuracy: %.2f \n" % (100. * test_acc[0][1]))
            file.write("ratio=0.3, test Accuracy: %.2f \n" % (100. * test_acc[0][2]))
            file.write("ratio=0.4, test Accuracy: %.2f \n" % (100. * test_acc[0][3]))
            file.write("ratio=0.5, test Accuracy: %.2f \n" % (100. * test_acc[0][4]))
            file.write("ratio=0.6, test Accuracy: %.2f \n" % (100. * test_acc[0][5]))
            file.write("ratio=0.7, test Accuracy: %.2f \n" % (100. * test_acc[0][6]))
            file.write("ratio=0.8, test Accuracy: %.2f \n" % (100. * test_acc[0][7]))
            file.write("ratio=0.9, test Accuracy: %.2f \n" % (100. * test_acc[0][8]))
            file.write("ratio=1.0, test Accuracy: %.2f \n" % (100. * test_acc[0][9]))
        sub_index = sub_index + 1
        writer.close()


    test_final = np.mean(test_acc_final, 0)
    #print type(test_final)
    #print test_final
    print ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print ("ratio=0.1, test Accuracy:   %.2f " % (100. * test_final[0]))
    print ("ratio=0.2, test Accuracy:   %.2f " % (100. * test_final[1]))
    print ("ratio=0.3, test Accuracy:   %.2f " % (100. * test_final[2]))
    print ("ratio=0.4, test Accuracy:   %.2f " % (100. * test_final[3]))
    print ("ratio=0.5, test Accuracy:   %.2f " % (100. * test_final[4]))
    print ("ratio=0.6, test Accuracy:   %.2f " % (100. * test_final[5]))
    print ("ratio=0.7, test Accuracy:   %.2f " % (100. * test_final[6]))
    print ("ratio=0.8, test Accuracy:   %.2f " % (100. * test_final[7]))
    print ("ratio=0.9, test Accuracy:   %.2f " % (100. * test_final[8]))
    print ("ratio=1.0, test Accuracy:   %.2f " % (100. * test_final[9]))
    print ("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    with open(log_file_name, 'a+') as file:
        file.write("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        file.write("Cross-subject performance is:\n")
        file.write("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        file.write("ratio=0.1, test Accuracy: %.2f \n" % (100. * test_final[0]))
        file.write("ratio=0.2, test Accuracy: %.2f \n" % (100. * test_final[1]))
        file.write("ratio=0.3, test Accuracy: %.2f \n" % (100. * test_final[2]))
        file.write("ratio=0.4, test Accuracy: %.2f \n" % (100. * test_final[3]))
        file.write("ratio=0.5, test Accuracy: %.2f \n" % (100. * test_final[4]))
        file.write("ratio=0.6, test Accuracy: %.2f \n" % (100. * test_final[5]))
        file.write("ratio=0.7, test Accuracy: %.2f \n" % (100. * test_final[6]))
        file.write("ratio=0.8, test Accuracy: %.2f \n" % (100. * test_final[7]))
        file.write("ratio=0.9, test Accuracy: %.2f \n" % (100. * test_final[8]))
        file.write("ratio=1.0, test Accuracy: %.2f \n" % (100. * test_final[9]))
Example #7
def main():
    global args
    args = parser.parse_args()

    # Check if CUDA is enabled
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Load data
    root = args.datasetPath

    print('Prepare files')

    label_file = 'labels.txt'
    list_file = 'graphs.txt'
    with open(os.path.join(root, label_file), 'r') as f:
        l = f.read()
        classes = [int(float(s) > 0.5)
                   for s in l.split()]  #classes based on 0.5
        # just makes them all 1
        # print(set(classes))
        unique, counts = np.unique(np.array(classes), return_counts=True)
        print(dict(zip(unique, counts)))
    with open(os.path.join(root, list_file), 'r') as f:

        files = [s + '.pkl' for s in f.read().splitlines()]

    train_ids, train_classes, valid_ids, valid_classes, test_ids, test_classes = divide_datasets(
        files, classes)

    #shuffle here
    c = list(zip(train_ids, train_classes))

    random.shuffle(c)

    train_ids, train_classes = zip(*c)
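    # Shuffle ids and labels jointly: zip them together, shuffle the pairs,
    # then unzip back into two aligned tuples.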

    data_train = PrGr(root, train_ids, train_classes)
    print(data_train[0])
    print(len(data_train))
    data_valid = PrGr(root, valid_ids, valid_classes)
    data_test = PrGr(root, test_ids, test_classes)
    print(len(data_test))
    # Define model and optimizer
    print('Define model')
    # Select one graph
    g_tuple, l = data_train[6]
    g, h_t, e = g_tuple

    print('\tStatistics')
    stat_dict = datasets.utils.get_graph_stats(data_train, ['degrees'])

    # Data Loader
    train_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        data_test,
        batch_size=args.batch_size,
        collate_fn=datasets.utils.collate_g,
        num_workers=args.prefetch,
        pin_memory=True)

    print('\tCreate model')
    num_classes = 2
    print(stat_dict['degrees'])
    model = MpnnDuvenaud(stat_dict['degrees'],
                         [len(h_t[0]), len(list(e.values())[0])], [7, 3, 5],
                         11,
                         num_classes,
                         type='classification')

    print('Check cuda')

    print('Optimizer')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    criterion = nn.NLLLoss()

    evaluation = utils.accuracy

    print('Logger')
    logger = Logger(args.logPath)

    lr_step = (args.lr - args.lr * args.lr_decay) / (
        args.epochs * args.schedule[1] - args.epochs * args.schedule[0])

    ### get the best checkpoint if available without training
    best_acc1 = 0
    # if args.resume:
    #     checkpoint_dir = args.resume
    #     best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
    #     if not os.path.isdir(checkpoint_dir):
    #         os.makedirs(checkpoint_dir)
    #     if os.path.isfile(best_model_file):
    #         print("=> loading best model '{}'".format(best_model_file))
    #         checkpoint = torch.load(best_model_file)
    #         args.start_epoch = checkpoint['epoch']
    #         best_acc1 = checkpoint['best_acc1']
    #         model.load_state_dict(checkpoint['state_dict'])
    #         optimizer.load_state_dict(checkpoint['optimizer'])
    #         print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'],
    #                                                                          best_acc1))
    #     else:
    #         print("=> no best model found at '{}'".format(best_model_file))

    print('Check cuda')
    if args.cuda:
        print('\t* Cuda')
        model = model.cuda()
        criterion = criterion.cuda()

    # Epoch for loop
    for epoch in range(0, args.epochs):

        if args.epochs * args.schedule[0] < epoch < args.epochs * args.schedule[1]:
            args.lr -= lr_step
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, evaluation,
              logger)

        # evaluate on validation set
        acc1 = validate(valid_loader, model, criterion, evaluation, logger)

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
            is_best=is_best,
            directory=args.resume)

        # Logger step
        logger.log_value('learning_rate', args.lr).step()

    # get the best checkpoint and test it with test set
    # if args.resume:
    #     checkpoint_dir = args.resume
    #     best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
    #     if not os.path.isdir(checkpoint_dir):
    #         os.makedirs(checkpoint_dir)
    #     if os.path.isfile(best_model_file):
    #         print("=> loading best model '{}'".format(best_model_file))
    #         checkpoint = torch.load(best_model_file)
    #         args.start_epoch = checkpoint['epoch']
    #         best_acc1 = checkpoint['best_acc1']
    #         model.load_state_dict(checkpoint['state_dict'])
    #         optimizer.load_state_dict(checkpoint['optimizer'])
    #         print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'],
    #                                                                          best_acc1))
    #     else:
    #         print("=> no best model found at '{}'".format(best_model_file))

    # For testing
    validate(test_loader, model, criterion, evaluation)
    torch.save(model, 'test.pth')
    print(train_classes)
    print(valid_classes)
    print(test_classes)
Example #8
def main():
    global args, best_loss

    # create necessary folders
    os.makedirs('./checkpoints', exist_ok=True)
    os.makedirs('./runs', exist_ok=True)

    # create datasets
    train_dataset = dsets.CARVANA(root=args.dir,
                                  train=True,
                                  transform=unet_transforms.Compose([
                                      unet_transforms.Scale((256, 256)),
                                      unet_transforms.ToNumpy(),
                                      unet_transforms.RandomShiftScaleRotate(),
                                      unet_transforms.ToTensor(),
                                  ]))

    val_dataset = dsets.CARVANA(root=args.dir,
                                train=True,
                                transform=unet_transforms.Compose([
                                    unet_transforms.Scale((256, 256)),
                                    unet_transforms.ToNumpy(),
                                    unet_transforms.ToTensor()
                                ]))

    num_val = int(len(train_dataset) * args.perc_val)
    num_train = len(train_dataset) - num_val
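    # The last perc_val fraction of the images is held out for validation via
    # ChunkSampler(num_val, start=num_train) in val_loader below (assuming
    # ChunkSampler yields indices start .. start + num).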

    # define the dataloader with the previous dataset
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # sampler=data_samplers.ChunkSampler(num_train, start=0),
        num_workers=4)

    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        sampler=data_samplers.ChunkSampler(num_val, start=num_train),
        num_workers=4)

    # create model, define the loss function and the optimizer.
    # Move everything to cuda

    model = models.UNet1024().cuda()

    criterion = {
        'loss': models.BCEplusDice().cuda(),
        'acc': models.diceAcc().cuda()
    }
    optimizer = optim.SGD(model.parameters(),
                          weight_decay=args.weight_decay,
                          lr=args.lr,
                          momentum=args.momentum,
                          nesterov=args.nesterov)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, checkpoint['epoch'])
        return

    # run the training loop
    for epoch in range(args.start_epoch, args.epochs):
        # adjust lr according to the current epoch
        model_utils.adjust_learning_rate(optimizer, epoch, args.lr,
                                         args.adjust_epoch)

        # train for one epoch
        curr_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate the model
        # curr_loss = validate(val_loader, model, criterion, epoch)

        # store best loss and save a model checkpoint
        is_best = curr_loss < best_loss
        best_loss = min(curr_loss, best_loss)
        data_utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            folder='./checkpoints/',
            filename=args.arch)

    logger.close()