コード例 #1
0
def get_learn(data, device_ids=(0, 1, 2, 3),
              weights_path='./models/unet_res18_allres_init.pth'):
    """Build a Learner around the ResNet-18 U-Net from ``nb_resnet_unet``.

    The model is created with ``nb_resnet_unet.get_unet_res18(1, True)`` and
    initialised from the state dict stored at ``weights_path``. The loss is
    ``nb_loss_metrics.balance_bce`` with ``balance_ratio=1``; dice loss,
    balanced BCE and mask IoU are appended to the learner's metrics.

    Args:
        data: data object handed to ``Learner`` (presumably a fastai
            DataBunch). When ``data.device`` is CUDA the model is wrapped in
            ``torch.nn.DataParallel``.
        device_ids: GPU indices passed to ``DataParallel``. Defaults to the
            four GPUs 0-3 this function always used; ``None`` lets
            ``DataParallel`` use every visible GPU.
        weights_path: checkpoint file holding the initial state dict.

    Returns:
        The configured ``Learner``.
    """
    # create model
    model = nb_resnet_unet.get_unet_res18(1, True)
    # map_location='cpu': without it torch.load restores tensors onto the
    # device they were saved from. That fails on machines lacking that GPU,
    # and the CPU path below shows CPU runs are expected. load_state_dict
    # then copies the values into the model's own parameters.
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))

    # create learner
    learn = Learner(data, model)

    # Split into layer groups; computed on the unwrapped model, before the
    # optional DataParallel wrapping below.
    learn.layer_groups = split_model(learn.model)

    # set multi-gpu
    if data.device.type == 'cuda':
        ids = list(device_ids) if device_ids is not None else None
        learn.model = torch.nn.DataParallel(learn.model, device_ids=ids)

    # Loss: balanced BCE. Alternatives tried previously here were
    # partial(nb_loss_metrics.combo_loss, balance_ratio=1) and
    # nb_loss_metrics.dice_loss.
    learn.loss_func = partial(nb_loss_metrics.balance_bce, balance_ratio=1)

    # Add metrics (order preserved: dice loss, balanced BCE, mask IoU).
    learn.metrics += [
        nb_loss_metrics.dice_loss,
        partial(nb_loss_metrics.balance_bce, balance_ratio=1),
        nb_loss_metrics.mask_iou,
    ]

    return learn
コード例 #2
0
def get_learn_detectsym_17clas(data, gaf, clas_weights=weights):
    '''
    Build a Learner for the 17-class symbol-detection dataset.

    The model comes from ``resnet_ssd.get_resnet18_1ssd(num_classes=17)`` and
    is initialised from ``./models/pretrained_res18_1ssd.pth``. The loss is
    ``anchors_loss_metrics.yolo_L``; six ``anchors_loss_metrics`` functions
    are appended as metrics.

    Args:
        data: data object handed to ``Learner`` (presumably a fastai
            DataBunch). When ``data.device`` is CUDA the model is wrapped in
            ``torch.nn.DataParallel`` using the module-level ``device_ids``.
        gaf: passed through unchanged to the loss and to every metric.
            NOTE(review): its type/meaning is defined elsewhere — confirm.
        clas_weights: class weights passed to ``yolo_L``, ``clas_L``,
            ``cent_L`` and ``pConf_L``. The default is the module-level
            ``weights`` object as it was when this function was defined.

    Returns:
        The configured ``Learner``.
    '''
    # create model
    model = resnet_ssd.get_resnet18_1ssd(num_classes=17)
    # map_location='cpu': without it torch.load restores tensors onto the
    # device they were saved from. That fails on hosts lacking that GPU, and
    # the CPU path below shows CPU runs are expected. load_state_dict copies
    # the values into the model's own parameters.
    model.load_state_dict(
        torch.load('./models/pretrained_res18_1ssd.pth', map_location='cpu'))

    # create learner
    learn = Learner(data, model)

    # Split into layer groups on the unwrapped model (before DataParallel).
    learn.layer_groups = split_model(learn.model)

    # set multi-gpu; device_ids is a module-level setting defined elsewhere
    if data.device.type == 'cuda':
        learn.model = torch.nn.DataParallel(learn.model, device_ids=device_ids)

    # set loss func
    learn.loss_func = partial(anchors_loss_metrics.yolo_L,
                              gaf=gaf,
                              conf_th=1,
                              clas_weights=clas_weights,
                              lambda_nconf=10)

    # Add metrics (order preserved).
    learn.metrics += [
        partial(anchors_loss_metrics.clas_L,
                gaf=gaf,
                clas_weights=clas_weights),
        partial(anchors_loss_metrics.cent_L,
                gaf=gaf,
                clas_weights=clas_weights),
        partial(anchors_loss_metrics.pConf_L,
                gaf=gaf,
                clas_weights=clas_weights),
        partial(anchors_loss_metrics.nConf_L, gaf=gaf, conf_th=1),
        partial(anchors_loss_metrics.clas_acc, gaf=gaf),
        partial(anchors_loss_metrics.cent_d, gaf=gaf),
    ]

    return learn
コード例 #3
0
# Validation loader: reuses the dlc.gdf_col collate function, with pinned
# memory disabled and num_workers=0 (batches built in the main process).
v_data = DLDataLoader(v_chain,
                      collate_fn=dlc.gdf_col,
                      pin_memory=False,
                      num_workers=0)

# Bundle the training loader (t_data, built earlier in the script, outside
# this view) and the validation loader into a DataBunch on the CUDA device.
databunch = DataBunch(t_data, v_data, collate_fn=dlc.gdf_col, device="cuda")
# Elapsed time for the data-preparation stage; `start` was set earlier.
# NOTE(review): presumably `time` is time.time, so this prints seconds.
t_final = time() - start
print(t_final)
print("Creating model")
start = time()
# Tabular network: embedding sizes from `embeddings`, one input per
# continuous column, hidden layers of 512 and 256 units, and 2 outputs.
# It is trained with cross-entropy and reports accuracy.
model = TabularModel(emb_szs=embeddings,
                     n_cont=len(cont_names),
                     out_sz=2,
                     layers=[512, 256])
learn = Learner(databunch, model, metrics=[accuracy])
learn.loss_func = torch.nn.CrossEntropyLoss()
t_final = time() - start
print(t_final)
print("Finding learning rate")
start = time()
# Run the LR range test, then plot loss vs. learning rate with the
# suggested value marked. NOTE(review): in fastai v1 `show_moms` is a
# parameter of Recorder.plot_lr; confirm Recorder.plot actually uses it.
learn.lr_find()
learn.recorder.plot(show_moms=True, suggestion=True)
# The learning rate is hard-coded rather than read back from the finder.
# It was presumably picked by inspecting the plot above.
learning_rate = 1.32e-2
epochs = 1
# This interval covers both lr_find and the plotting call.
t_final = time() - start
print(t_final)
print("Running Training")
start = time()
# One-cycle schedule for `epochs` epochs with maximum LR `learning_rate`.
learn.fit_one_cycle(epochs, learning_rate)
t_final = time() - start
print(t_final)
コード例 #4
0
#learn.load("/scratch/leuven/412/vsc41276/mp/CNN_16epochs", strict=False)
# (Disabled above: resume from a previously saved checkpoint.)

# Split the model into layer groups with _schnet_split (defined elsewhere),
# then freeze. In fastai v1, freeze() leaves only the last layer group
# trainable.
learn.split(_schnet_split)

learn.freeze()

# Print a model summary. NOTE(review): `summary` and `device` come from
# outside this view; confirm which summary helper this is.
summary(model, device)

# Optimizer: the custom Adam_sdr when args.sdr is set (an SGD_sdr variant is
# left commented out inline), otherwise plain Adam.
if args.sdr:
    learn.opt_func = Adam_sdr#partial(SGD_sdr, weight_decay=0)#
else:
    learn.opt_func = Adam
    
# Loss and triplet selection.
# NOTE(review): OnlineTripletLoss with SemihardNegativeTripletSelector,
# which presumably mines triplets inside each batch, is used when
# args.triplets_online is False. The True branch uses TripletLoss plus
# TripletSetter callbacks. Confirm the condition is not inverted.
if args.triplets_online:    
    learn.loss_func = TripletLoss(margin=1.0)

    # Triplet-setting callbacks using semihard_negative selection with margin
    # 1.0: 100 triplets per class from the train loaders, 125 from val.
    learn.callbacks.append(TripletSetter(model, train_loader, train_loader2, semihard_negative, margin=1.0, triplets_per_class=100))
    learn.callbacks.append(TripletSetter(model, val_loader  , val_loader2  , semihard_negative, margin=1.0, triplets_per_class=125))
else:
    learn.loss_func = OnlineTripletLoss(1.0 ,SemihardNegativeTripletSelector(margin=1.0))

# (Disabled: LR range test and its plot.)
#learn.lr_find(start_lr=1e-6,end_lr=5e-2,no_grad_val = False,num_it=300)
#plot_recorder(learn.recorder)

# NOTE(review): with the LR-finder lines disabled, this presumably has no
# new figure to display.
plt.show()

# Release cached, unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()
# One one-cycle run of 1 epoch: max_lr 1e-2, starting LR max_lr/div_factor,
# increasing phase over the first 25% of iterations. NOTE(review):
# no_grad_val is not a stock fastai v1 fit_one_cycle argument, presumably
# from a patched Learner.
learn.fit_one_cycle(cyc_len=1, max_lr=1e-2, div_factor=20.0, pct_start=0.25,no_grad_val=False) #,wd=#1e-7#,moms=(0.95, 0.85)

# (Disabled: save the trained weights.)
#learn.save('16epochs_triplets')