コード例 #1
0
def main():
    """Fine-tune the network selected by ``args.basenet`` on the training split.

    Builds the model, loads weights from
    ``model3/BDXJTU2019_SGD_21_6000.pth`` (non-strict) and calls ``train``
    once per epoch with an Adam optimizer.

    Relies on module-level names that are not defined in this function:
    ``args``, ``weights``, ``train`` and ``adjust_learning_rate``.

    Raises:
        ValueError: if ``args.basenet`` is not a supported network name.
    """
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    # Create model. Fail loudly on an unknown name instead of relying on an
    # ``assert`` (asserts are stripped when Python runs with -O).
    if args.basenet == 'MultiModal':
        model = MultiModalNet('se_resnet50', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        raise ValueError('Unsupported basenet: {!r}'.format(args.basenet))
    model = model.cuda()
    cudnn.benchmark = True

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root=args.dataset_root,
                                  mode='MM_cleaned_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train, args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True, pin_memory=True)

    # The validation loader is built but not consumed: the validation step
    # is disabled in this script.
    Dataset_val = MM_BDXJTU2019(root=args.dataset_root, mode='val')
    Dataloader_val = data.DataLoader(Dataset_val, batch_size=8,
                                     num_workers=args.num_workers,
                                     shuffle=True, pin_memory=True)

    # ``weights`` is expected to be defined at module level.
    criterion = nn.CrossEntropyLoss(weight=weights).cuda()

    print("info:",args.batch_size)
    print("load pretrained model from model3 _____21_6000___adam")
    state_dict1 = torch.load('model3/BDXJTU2019_SGD_21_6000.pth')
    # strict=False: keys that do not match the current model are ignored.
    model.load_state_dict(state_dict1, strict=False)

    # Only parameters that still require gradients are optimized.
    Optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=0.001)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # NOTE(review): no checkpoint is written here (saving was disabled in
        # the original script), so the trained weights are not persisted.
コード例 #2
0
def GetNetTensor():
    """Run ``TensorGenerator`` on the validation split with the latest best model.

    The checkpoint is the ``.pth`` file in ``weights/best_models`` whose name
    ends with the largest number. If no such file exists a message is printed
    and nothing else happens.

    Relies on module-level names not defined here: ``TensorGenerator`` and
    ``NET_VAL_DATA_PATH``.
    """
    # Priors
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    Dataset_val = MM_BDXJTU2019(root='data', mode='val',
                                TRAIN_IMAGE_DIR='train_image_raw')
    Dataloader_val = data.DataLoader(Dataset_val, 1, num_workers=0,
                                     shuffle=False, pin_memory=False)

    # Network
    cudnn.benchmark = True
    # The model name is taken from the first entry of ./weights other than
    # 'best_models'. NOTE(review): os.listdir order is arbitrary, so this
    # presumably expects exactly one model directory -- confirm.
    MODEL_NAME = [i for i in os.listdir('./weights') if i != 'best_models'][0]
    BEST_DIR = op.join('weights', 'best_models')

    # Find & Load Best_Model. Only keep .pth files that carry a number so the
    # numeric sort below cannot fail on an unrelated file.
    pthlist = []
    if op.isdir(BEST_DIR):
        pthlist = [i for i in os.listdir(BEST_DIR)
                   if i.endswith('.pth') and re.search(r'\d', i)]
    if not pthlist:
        print('No best_model pth found in ./weights/best_models.')
        return

    # int() instead of eval(): eval fails on leading zeros ("007") and would
    # execute text taken from a file name.
    pthlist.sort(key=lambda x: int(re.findall(r'\d+', x)[-1]))

    best_model = MultiModalNet(MODEL_NAME, 'DPN26', 0.5).cuda()
    best_model.load_state_dict(torch.load(op.join(BEST_DIR, pthlist[-1])))
    best_model.eval()
    TensorGenerator(Dataloader_val, NET_VAL_DATA_PATH, best_model)
コード例 #3
0
def GeResult():
    """Evaluate a saved ``MultiModalNet`` on the '1_val' split.

    Prints running and final accuracy, then shows the raw and normalized
    confusion matrices. Relies on the module-level ``device``.
    """
    # Dataset
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='1_val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=1,
                                     num_workers=4,
                                     shuffle=True,
                                     pin_memory=True)

    class_names = [
        '001', '002', '003', '004', '005', '006', '007', '008', '009'
    ]

    # construct network
    epoch = 12
    net = MultiModalNet('se_resnet152', 'DPN26', 0.5)
    net.to(device)
    net.load_state_dict(
        torch.load(
            '/home/dell/Desktop/2019BaiduXJTU/weights/se_resnet152_se_resnext50_32x4d_resample_pretrained_80w_1/inception_005.pth'
        ))
    # NOTE(review): the message reports ``epoch`` although the checkpoint
    # path above is fixed and does not depend on it.
    print('load ' + str(epoch) + ' epoch model')
    net.eval()

    results = []
    results_anno = []

    # Inference only: disable autograd so no graph is built per sample.
    with torch.no_grad():
        for i, (Input_img, Input_vis, Anno) in enumerate(Dataloader_val):
            Input_img = Input_img.to(device)
            Input_vis = Input_vis.to(device)

            ConfTensor = net.forward(Input_img, Input_vis)
            # top-1 class index along dim 1
            _, pred = ConfTensor.data.topk(1, 1, True, False)

            results.append(pred.item())
            results_anno.append(Anno)  # append annotation results

            if (i + 1) % 1000 == 0:
                print(i + 1)
                print('Accuracy of Orignal Input: %0.6f' %
                      (accuracy_score(results, results_anno, normalize=True)))

    # print accuracy of different input
    print('Accuracy of Orignal Input: %0.6f' %
          (accuracy_score(results, results_anno, normalize=True)))

    cnf_matrix = confusion_matrix(results_anno, results)
    # Trace = number of correctly classified samples.
    cnf_tr = np.trace(cnf_matrix)

    cnf_tr = cnf_tr.astype('float')
    print(cnf_tr / len(Dataset_val))
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          title='Confusion matrix, without normalization')
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')
    plt.show()
コード例 #4
0
def main():
    """Train the network selected by ``args.basenet`` and save it every epoch.

    The model is wrapped in ``nn.DataParallel`` when more than one GPU is
    visible. Checkpoints are written to
    ``weights/<basenet>_se_resnext50_32x4d_resample_pretrained_80w_1/``.

    Relies on module-level names not defined here: ``args``, ``device``,
    ``train`` and ``adjust_learning_rate``.

    Raises:
        ValueError: if ``args.basenet`` is not a supported network name.
    """
    import os

    # create model
    if args.basenet == 'se_resnet152':
        model = MultiModalNet('se_resnet152', 'DPN26', 0.5)
    elif args.basenet == 'se_resnext50_32x4d':
        model = MultiModalNet1('se_resnext50_32x4d', 'DPN26', 0.5)
    elif args.basenet == 'se_resnet50':
        model = MultiModalNet1('se_resnet50', 'DPN26', 0.5)
    elif args.basenet == 'densenet201':
        model = MultiModalNet2('densenet201', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        # Previously an unknown name surfaced later as an UnboundLocalError.
        raise ValueError('Unsupported basenet: {!r}'.format(args.basenet))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)
    model.to(device)

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                  mode='MM_1_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train,
                                       128,
                                       num_workers=4,
                                       shuffle=True,
                                       pin_memory=True)

    # The validation loader is built but not consumed: the validation step
    # is disabled in this script.
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=32,
                                     num_workers=4,
                                     shuffle=True,
                                     pin_memory=True)

    criterion = nn.CrossEntropyLoss().to(device)
    Optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Make sure the checkpoint directory exists before spending an epoch on
    # training; torch.save does not create missing directories.
    ckpt_dir = ('weights/' + args.basenet +
                '_se_resnext50_32x4d_resample_pretrained_80w_1/')
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # Save the underlying network. ``.module`` only exists when the model
        # was wrapped in DataParallel (multi-GPU); on a single GPU the plain
        # model is saved instead of raising AttributeError.
        net_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(net_to_save.state_dict(),
                   ckpt_dir + 'BDXJTU2019_SGD_' + repr(epoch) + '.pth')
コード例 #5
0
def main():
    """Train ``multiscale_se_resnext_HR`` with automatic resume and best-model tracking.

    If ``weights/<MODEL_NAME>`` already contains ``.pth`` checkpoints, the one
    whose name ends with the largest number is loaded and training resumes
    from that epoch. After every epoch the model is validated and saved; it
    is also copied to ``weights/best_models`` whenever top-1 improves.

    Relies on module-level names not defined here: ``args``, ``weights``,
    ``best_log``, ``train``, ``validate``, ``log``, ``log_dict`` and
    ``adjust_learning_rate``.
    """

    def _trailing_number(name):
        """Return the last run of digits in *name* as an int."""
        # int() instead of eval(): eval fails on leading zeros and would
        # execute text taken from a file name.
        return int(re.findall(r'\d+', name)[-1])

    ###Enter Main Func:
    mp.set_start_method('spawn')
    # create model
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    MODEL_NAME = 'multiscale_se_resnext_HR'
    MODEL_DIR = op.join('weights', MODEL_NAME)
    BEST_DIR = op.join('weights', 'best_models')
    # makedirs also creates the parent 'weights' directory when missing.
    os.makedirs(MODEL_DIR, exist_ok=True)
    os.makedirs(BEST_DIR, exist_ok=True)

    model = MultiModalNet(MODEL_NAME, 'DPN26', 0.5)
    model = model.cuda()
    cudnn.benchmark = True

    # Resume from the most recent checkpoint of this model, if any.
    RESUME = False
    pthlist = [i for i in os.listdir(MODEL_DIR) if i[-4:] == '.pth']
    if pthlist:
        pthlist.sort(key=_trailing_number)
        MODEL_PATH = op.join(MODEL_DIR, pthlist[-1])
        model.load_state_dict(torch.load(MODEL_PATH))
        RESUME = True

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root=args.dataset_root,
                                  mode='train',
                                  transform=Aug,
                                  TRAIN_IMAGE_DIR='train_image_raw')
    Dataloader_train = data.DataLoader(Dataset_train,
                                       args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

    Dataset_val = MM_BDXJTU2019(root=args.dataset_root,
                                mode='val',
                                TRAIN_IMAGE_DIR='train_image_raw')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=64,
                                     num_workers=args.num_workers,
                                     shuffle=True,
                                     pin_memory=True)

    # ``weights`` is expected to be defined at module level.
    criterion = nn.CrossEntropyLoss(weight=weights).cuda()
    Optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Checkpoints are named with ``epoch + 1``, so the trailing number of the
    # resumed file is the next epoch index to run.
    args.start_epoch = _trailing_number(MODEL_PATH) if RESUME else 0

    # Recover the best top-1 seen so far: from the log when resuming,
    # otherwise by re-validating the newest best-model checkpoint.
    best_pred1, best_preds = 0, {}
    if RESUME and op.isfile(best_log):
        with open(best_log) as log_file:
            best_preds = json.load(log_file)
        best_pred1 = best_preds['best_pred1']
    else:
        best_pths = [i for i in os.listdir(BEST_DIR) if i[-4:] == '.pth']
        if best_pths:
            best_pths.sort(key=_trailing_number)
            best_model = MultiModalNet(MODEL_NAME, 'DPN26', 0.5).cuda()
            best_model.load_state_dict(
                torch.load(op.join(BEST_DIR, best_pths[-1])))
            best_pred1 = validate(Dataloader_val,
                                  best_model,
                                  criterion,
                                  printable=False)[0]
            log_dict(_trailing_number(best_pths[-1]), best_pred1)

    log('#Resume: Another Start from Epoch {}'.format(args.start_epoch + 1))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)
        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)
        # evaluate on validation set
        pred1, pred5 = validate(Dataloader_val, model, criterion)

        # Always save the epoch checkpoint; additionally keep a copy under
        # best_models when top-1 improved.
        COMMON_MODEL_PATH = op.join(MODEL_DIR,
                                    'SGD_fold1_{}.pth'.format(epoch + 1))
        BEST_MODEL_PATH = op.join(
            BEST_DIR, '{}_SGD_fold1_{}.pth'.format(MODEL_NAME, epoch + 1))
        torch.save(model.state_dict(), COMMON_MODEL_PATH)
        if pred1 > best_pred1:
            best_pred1 = pred1
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            log_dict(epoch + 1, best_pred1)
        log('Epoch:{}\tpred1:{}\tBest_pred1:{}\n'.format(
            epoch + 1, pred1, best_pred1))
コード例 #6
0
def GeResult():
    """Evaluate a three-model ensemble on the 'val' split.

    The L2-normalized outputs of the three networks are averaged; a sample is
    only counted when its top averaged score exceeds 0.85. Prints running and
    final accuracy over the counted samples and shows the confusion matrices.
    Relies on the module-level ``device``.
    """
    # Dataset
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=1,
                                     num_workers=2,
                                     shuffle=True,
                                     pin_memory=True)

    class_names = [
        '001', '002', '003', '004', '005', '006', '007', '008', '009'
    ]

    # construct networks
    net1 = MultiModalNet1('se_resnet50', 'DPN26', 0.5)
    net1.load_state_dict(
        torch.load(
            '/home/dell/Desktop/2019BaiduXJTU/weights/se_resnet50_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_9.pth'
        ))
    net1.to(device)
    net1.eval()
    net2 = MultiModalNet('se_resnet152', 'DPN26', 0.5)
    net2.load_state_dict(
        torch.load(
            '/home/dell/Desktop/2019BaiduXJTU/weights/se_resnet152_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_4.pth'
        ))
    net2.to(device)
    net2.eval()
    net3 = MultiModalNet2('densenet201', 'DPN26', 0.5)
    net3.load_state_dict(
        torch.load(
            '/home/dell/Desktop/2019BaiduXJTU/weights/densenet201_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_3.pth'
        ))
    net3.to(device)
    net3.eval()

    normalize = torch.nn.functional.normalize

    results = []
    results_anno = []

    # Inference only: disable autograd so no graph is built for the three
    # forward passes of every sample.
    with torch.no_grad():
        for i, (Input_img, Input_vis, Anno) in enumerate(Dataloader_val):
            Input_img = Input_img.to(device)
            Input_vis = Input_vis.to(device)

            ConfTensor1 = net1.forward(Input_img, Input_vis)
            ConfTensor2 = net2.forward(Input_img, Input_vis)
            ConfTensor3 = net3.forward(Input_img, Input_vis)

            # Average of the normalized outputs of the three models.
            ConfTensor = (normalize(ConfTensor1) + normalize(ConfTensor2) +
                          normalize(ConfTensor3)) / 3

            score, pred = ConfTensor.data.topk(1, 1, True, False)
            # Keep only confident predictions.
            if score.item() > 0.85:
                results.append(pred.item())
                results_anno.append(Anno)  # append annotation results

            if (i + 1) % 2000 == 0:
                print(i + 1)
                print(len(results))
                print('Accuracy of Orignal Input: %0.6f' %
                      (accuracy_score(results, results_anno, normalize=True)))

    # print accuracy of different input
    print('Accuracy of Orignal Input: %0.6f' %
          (accuracy_score(results, results_anno, normalize=True)))

    cnf_matrix = confusion_matrix(results_anno, results)
    cnf_tr = np.trace(cnf_matrix)
    cnf_tr = cnf_tr.astype('float')
    # NOTE(review): the divisor is the full dataset size although only the
    # confident samples are in the matrix -- confirm this is intended.
    print(cnf_tr / len(Dataset_val))
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          title='Confusion matrix, without normalization')
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')
    plt.show()
コード例 #7
0
ファイル: MM_val.py プロジェクト: wdczs/2019Baidu-XJTU_URFC
def GeResult():
    """Evaluate the epoch-80 ``se_resnext50_32x4d`` checkpoint on the '1_val' split.

    Prints running and final accuracy, then shows the raw and normalized
    confusion matrices. Runs on CUDA device 0.
    """
    # Dataset
    Dataset_val = MM_BDXJTU2019(root='data', mode='1_val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=1,
                                     num_workers=2,
                                     shuffle=True,
                                     pin_memory=True)

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    class_names = [
        '001', '002', '003', '004', '005', '006', '007', '008', '009'
    ]

    # construct network
    epoch = 80
    net = MultiModalNet('se_resnext50_32x4d', 'DPN26', 0.5)
    net.load_state_dict(
        torch.load(
            '/home/zxw/2019BaiduXJTU/weights/MultiModal_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_'
            + str(epoch) + '.pth'))
    print('load ' + str(epoch) + ' epoch model')
    net.eval()

    results = []
    results_anno = []

    # Inference only: disable autograd so no graph is built per sample.
    with torch.no_grad():
        for i, (Input_img, Input_vis, Anno) in enumerate(Dataloader_val):
            Input_img = Input_img.cuda()
            Input_vis = Input_vis.cuda()

            ConfTensor = net.forward(Input_img, Input_vis)
            # top-1 class index along dim 1
            _, pred = ConfTensor.data.topk(1, 1, True, False)

            results.append(pred.item())
            results_anno.append(Anno)  # append annotation results

            if (i + 1) % 1000 == 0:
                print(i + 1)
                print('Accuracy of Orignal Input: %0.6f' %
                      (accuracy_score(results, results_anno, normalize=True)))

    # print accuracy of different input
    print('Accuracy of Orignal Input: %0.6f' %
          (accuracy_score(results, results_anno, normalize=True)))

    cnf_matrix = confusion_matrix(results_anno, results)
    # Trace = number of correctly classified samples.
    cnf_tr = np.trace(cnf_matrix)

    cnf_tr = cnf_tr.astype('float')
    print(cnf_tr / len(Dataset_val))
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          title='Confusion matrix, without normalization')
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')
    plt.show()
コード例 #8
0
ファイル: val.py プロジェクト: linan2017/URFC-top4
def GeResult():
    """Dump normalized validation scores of the newest checkpoint to CSV.

    Selects, inside the weight directory of the hard-coded ``model_id``, the
    ``BDXJTU2019_SGD_<a>_<b>.pth`` file with the largest ``a * 1000000 + b``,
    runs it over the validation set, writes one CSV row of 9 normalized
    scores per sample to ``val_result/csvO[<model_id>_<a>].csv`` and saves the
    sample ids to ``name_id<model_id>.npy``.

    Raises:
        FileNotFoundError: if the weight directory has no matching checkpoint.
    """
    import numpy as np

    # Priors
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    model_names=["","se_resnext101_32x4d","se_resnext50_32x4d","se_resnet50","densenet169","densenet121"]
    weight_names=["","model1","model2","model3","model_densenet169","model_densenet121"]

    model_id = 5
    batch_size = 4
    print(model_id)
    model_name = model_names[model_id]
    weight_name = weight_names[model_id]

    # Pick the checkpoint with the highest (first number, second number)
    # ranking encoded in its file name.
    best_key = None
    MAXfile = None
    MAXs1 = None
    for fname in os.listdir(weight_name):
        if fname.count("_") < 3:
            continue
        parts = fname.split("BDXJTU2019_SGD_")[1].split("_")
        s1 = int(parts[0])
        s2 = int(parts[1].split(".")[0])
        key = s1 * 1000000 + s2
        if best_key is None or best_key < key:
            best_key = key
            MAXfile = fname
            MAXs1 = s1
    if MAXfile is None:
        # Previously this surfaced as an unbound-variable error.
        raise FileNotFoundError(
            'No checkpoint matching BDXJTU2019_SGD_<a>_<b> found in '
            + weight_name)
    print(MAXfile)

    # Dataset: dense/nasnet backbones use their own dataset class.
    Dataset = MM_BDXJTU2019("data_txt", mode = 'val')
    if model_name.count("dense")>0:
        Dataset = MM_BDXJTU2019_for_dense("data_txt", mode='val')
    elif model_name.count("nasnet")>0:
        Dataset = MM_BDXJTU2019_for_nasnet("data_txt", mode='val')

    print(batch_size)
    Dataloader = data.DataLoader(Dataset, batch_size,
                                 num_workers=1,
                                 shuffle=False, pin_memory=True)

    # Network
    cudnn.benchmark = True
    net = MultiModalNet(model_name, 'DPN26', 0.5)
    print("weights  model"+str(model_id))
    net.load_state_dict(torch.load(weight_name+'/'+MAXfile))
    net.eval()

    cnt = 0
    # At least 50000 slots (the original fixed size); grown when the dataset
    # is larger so the id buffer cannot overflow.
    name_id = np.zeros([max(50000, len(Dataset))])

    csv_path = 'val_result/csvO['+str(model_id)+'_'+str(MAXs1)+'].csv'
    # ``with`` closes the CSV even if inference fails; no_grad avoids
    # building an autograd graph during pure inference.
    with open(csv_path, "w") as csvO, torch.no_grad():
        for Input_O, visit_tensor, anoss, ids in Dataloader:
            ConfTensor_Os = net.forward(Input_O.cuda(), visit_tensor.cuda())
            for row in range(ConfTensor_Os.shape[0]):
                ConfTensor_O = ConfTensor_Os[row].reshape([1, 9])
                name_id[cnt] = int(ids[row])
                preds_temp = torch.nn.functional.normalize(ConfTensor_O)
                # One row: 9 scores, each followed by a comma.
                scores = preds_temp[0].tolist()
                csvO.write(",".join(str(v) for v in scores) + ",\n")
                if cnt % 100 == 0:
                    print(cnt)
                cnt += 1

    print(name_id[:5])
    np.save("name_id"+str(model_id)+".npy", name_id)
コード例 #9
0
def main():
    """Train the network selected by ``args.basenet`` and save it every epoch.

    Checkpoints are written to
    ``weights/<basenet>_se_resnext50_32x4d_resample_pretrained_80w_1/``.

    Relies on module-level names not defined here: ``args``, ``weights``,
    ``train`` and ``adjust_learning_rate``.

    Raises:
        ValueError: if ``args.basenet`` is not a supported network name.
    """
    import os

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    # create model
    if args.basenet == 'MultiModal':
        model = MultiModalNet('se_resnext50_32x4d', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        # Previously an unknown name surfaced later as an UnboundLocalError.
        raise ValueError('Unsupported basenet: {!r}'.format(args.basenet))

    model = model.cuda()
    cudnn.benchmark = True

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root=args.dataset_root,
                                  mode='MM_1_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train,
                                       args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

    # The validation loader is built but not consumed: the validation step
    # is disabled in this script.
    Dataset_val = BDXJTU2019(root=args.dataset_root, mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=8,
                                     num_workers=args.num_workers,
                                     shuffle=True,
                                     pin_memory=True)

    # ``weights`` is expected to be defined at module level.
    criterion = nn.CrossEntropyLoss(weight=weights).cuda()

    # Only parameters that still require gradients are optimized.
    Optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Make sure the checkpoint directory exists before spending an epoch on
    # training; torch.save does not create missing directories.
    ckpt_dir = ('weights/' + args.basenet +
                '_se_resnext50_32x4d_resample_pretrained_80w_1/')
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # save a checkpoint after every epoch
        torch.save(model.state_dict(),
                   ckpt_dir + 'BDXJTU2019_SGD_' + repr(epoch) + '.pth')