Esempio n. 1
0
def save_network(network, epoch_label, save_dir='./checks/mpn_no_ca/1501/1/'):
    """Save ``network``'s weights to ``<save_dir>/net_<epoch_label>.pth``.

    The module is moved to the CPU (in place) before ``state_dict()`` is taken,
    so the checkpoint contains CPU tensors. Afterwards it is moved to the
    module-level ``device`` only when CUDA is actually available.

    Args:
        network: the ``torch.nn.Module`` to checkpoint; it is moved in place.
        epoch_label: value interpolated into the file name (e.g. an epoch
            number or ``'last'``).
        save_dir: output directory, created if missing. Defaults to the
            original hard-coded checkpoint directory.

    Returns:
        The path of the written checkpoint file.
    """
    save_filename = f'net_{epoch_label}.pth'
    # torch.save does not create parent directories; make sure it exists.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    # ``is_available`` must be *called*: the bare function object is always
    # truthy, which made this branch run even on CPU-only machines.
    if torch.cuda.is_available():
        network.to(device)
    return save_path

# ---------------------------------------------------------------------------
# Student model
# ---------------------------------------------------------------------------
# 751 output classes, 6 stripes (presumably the Market-1501 training
# identities — confirm against the dataset loader).
TSModel = TS(num_classes=751, num_stripes=6).to(device)

# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
# Online triplet loss; ``Hard`` (built with margin ``opt.margin1``) is the
# triplet selector handed to it.
triplet_selector_S = Hard(opt.margin1)
criterion_tri_S = OnlineTripletLoss(opt.margin1, triplet_selector_S)

# Two independent cross-entropy criteria (suffixes S / T — presumably student
# and teacher; TODO confirm with the training loop).
criterion_part_S, criterion_part_T = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

# NOTE(review): dim=0 compares along the first axis — confirm this matches the
# layout of the features it is applied to.
criterion_cosine = torch.nn.CosineSimilarity(dim=0, eps=1e-6)

# ---------------------------------------------------------------------------
# Optimisation: a single parameter group trained with Nesterov SGD; StepLR
# multiplies the learning rate by 0.1 every 20 calls to
# ``exp_lr_scheduler.step()``.
# ---------------------------------------------------------------------------
param_groups = [dict(params=TSModel.parameters(), lr=0.01)]
optimizer_ft = optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4,
                         nesterov=True)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1)
Esempio n. 2
0
    save_path = os.path.join(opt.outf, save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    if torch.cuda.is_available:
        network.cuda()


# ---------------------------------------------------------------------------
# Training-phase model: PCB with a 4768-way classifier.
# NOTE(review): an earlier comment here mentioned loading pretrained
# parameters without the classifier, but nothing is loaded in this section —
# confirm where (or whether) that happens.
# ---------------------------------------------------------------------------
model = PCB(class_num=4768)
model = model.cuda() if use_gpu else model

# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------
# Online triplet loss fed by a semi-hard-negative triplet selector; both are
# built with margin ``opt.margin``.
triplet_selector = SemihardNegativeTripletSelector(opt.margin)
criterion_tri = OnlineTripletLoss(opt.margin, triplet_selector)

# Commented-out alternative for the part loss: CrossEntropyLabelSmooth(4768).
criterion_part = nn.CrossEntropyLoss()
criterion_center = CenterLoss(4768)
criterion_focal = FocalLoss(gamma=2)
# updating rule for parameter
ignored_params = list(map(id, model.model.fc.parameters()))
ignored_params += (
    list(map(id, model.classifier0.parameters())) +
    list(map(id, model.classifier1.parameters())) +
    list(map(id, model.classifier2.parameters())) +
    list(map(id, model.classifier3.parameters())) +
    list(map(id, model.classifier4.parameters())) +
    list(map(id, model.classifier5.parameters())) +