Code example #1
File: train.py  Project: flychen321/reid_verify
def train_gcn(train_loader, model_siamese, loss_siamese_fn, optimizer_siamese, scheduler_siamese,
              model_gcn, loss_gcn_fn, optimizer_gcn, scheduler_gcn, num_epochs=25):
    global cnt
    since = time.time()
    model_gcn.train(True)
    model_siamese.eval()
    losses = []
    total_loss = 0
    for epoch in range(num_epochs):
        # pre-1.1 PyTorch convention: schedulers step at the start of each epoch
        scheduler_siamese.step()
        scheduler_gcn.step()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for batch_idx, (data, target) in enumerate(train_loader):
            target = target if len(target) > 0 else None
            if not type(data) in (tuple, list):
                data = (data,)
            if use_gpu:
                data = tuple(d.cuda() for d in data)
                if target is not None:
                    target = target.cuda()

            optimizer_gcn.zero_grad()

            with torch.no_grad():
                outputs = model_siamese(*data, target)

            outputs, target = model_gcn(*outputs)  # for SGGNN

            if type(outputs) not in (tuple, list):
                outputs = (outputs,)

            loss_inputs = outputs
            if target is not None:
                target = (target,)
                loss_inputs += target

            if use_gpu:
                loss_inputs = tuple(d.cuda() for d in loss_inputs)

            loss_outputs = loss_gcn_fn(*loss_inputs)
            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            losses.append(loss.item())
            total_loss += loss.item()
            loss.backward()
            optimizer_gcn.step()
            if batch_idx % 5 == 0:
                print('epoch = %2d  batch_idx = %4d  loss = %.5f' % (epoch, batch_idx, loss))
            # if batch_idx > 0:
            #     break
        save_network(model_gcn, name, 'gcn' + str(epoch))
        save_whole_network(model_gcn, name, 'whole_gcn' + str(epoch))
    time_elapsed = time.time() - since
    print('time = %f' % (time_elapsed))
    save_network(model_gcn, name, 'best_gcn')
    save_whole_network(model_gcn, name, 'whole_best_gcn')
    return model_gcn
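
The loop above freezes the siamese network (eval mode plus torch.no_grad()) and backpropagates only through the GCN head. A minimal self-contained sketch of that frozen-feature pattern, with stand-in modules rather than the project's networks:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Frozen-feature pattern: the backbone runs in eval mode under
# torch.no_grad(), so only the head receives gradients.
# Both modules are hypothetical stand-ins.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU()).eval()
head = nn.Linear(16, 2)
optimizer = torch.optim.SGD(head.parameters(), lr=0.01)

x = torch.randn(4, 8)
target = torch.randint(0, 2, (4,))

with torch.no_grad():            # features carry no autograd graph
    feats = backbone(x)
logits = head(feats)             # only the head is differentiated
loss = F.cross_entropy(logits, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()                 # updates head parameters only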
Code example #2
File: train.py  Project: flychen321/siamese_gcn
def train_model(train_loader, model, loss_fn, optimizer, num_epochs=25):
    global cnt
    since = time.time()
    model.train()
    # model.eval()
    losses = []
    total_loss = 0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for batch_idx, (data, target) in enumerate(train_loader):
            target = target if len(target) > 0 else None
            if not type(data) in (tuple, list):
                data = (data,)
            if use_gpu:
                data = tuple(d.cuda() for d in data)
                if target is not None:
                    target = target.cuda()

            optimizer.zero_grad()
            outputs = model(*data)  # for contrastive loss
            # outputs, target = model(*data, target) # for SGGNN

            if type(outputs) not in (tuple, list):
                outputs = (outputs,)

            loss_inputs = outputs
            if target is not None:
                target = (target,)
                loss_inputs += target

            loss_outputs = loss_fn(*loss_inputs)
            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            losses.append(loss.item())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            print('batch_idx = %4d  loss = %f' % (batch_idx, loss))
        if (epoch + 1) % 5 == 0:
            save_network(model, str(epoch + 1))
            save_whole_network(model, str(epoch + 1))
    time_elapsed = time.time() - since
    print('time = %f' % (time_elapsed))
    save_network(model, 'best')
    save_whole_network(model, 'best')
    return model
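
This loop folds the model outputs and the target into one tuple and calls loss_fn(*loss_inputs), so any criterion whose forward matches the unpacked signature plugs in. A minimal contrastive loss in the classic Hadsell-style form, offered as an assumption about what loss_fn could look like (the project's actual class is not shown):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed contrastive criterion compatible with loss_fn(*loss_inputs):
# it takes the two unpacked embeddings followed by the pair target.
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, target):
        # target == 1 for matching pairs, 0 for non-matching pairs
        dist = F.pairwise_distance(emb1, emb2)
        pos = target.float() * dist.pow(2)
        neg = (1 - target.float()) * F.relu(self.margin - dist).pow(2)
        return (pos + neg).mean()

loss_fn = ContrastiveLoss(margin=1.0)
outputs = (torch.randn(4, 128), torch.randn(4, 128))
target = torch.randint(0, 2, (4,))
loss_inputs = outputs + (target,)   # mirrors the tuple concatenation above
print(loss_fn(*loss_inputs))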
Code example #3
File: train.py  Project: flychen321/reid_verify
def train_model_siamese(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = model.state_dict()
    last_margin = 0.0
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_id_loss = 0.0
            running_verif_loss = 0.0
            running_id_corrects = 0.0
            running_verif_corrects = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                inputs, vf_labels, id_labels = data
                now_batch_size, c, h, w = inputs[0].shape
                if now_batch_size < opt.batchsize:  # next epoch
                    continue

                if type(inputs) not in (tuple, list):
                    inputs = (inputs,)
                if type(id_labels) not in (tuple, list):
                    id_labels = (id_labels,)
                if use_gpu:
                    inputs = tuple(d.cuda() for d in inputs)
                    id_labels = tuple(d.cuda() for d in id_labels)
                    if vf_labels is not None:
                        vf_labels = vf_labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs1, f1, outputs2, f2, feature, score = model(inputs[0], inputs[1])
                _, id_preds1 = torch.max(outputs1.data, 1)
                _, id_preds2 = torch.max(outputs2.data, 1)
                _, vf_preds = torch.max(score.data, 1)
                loss_id1 = criterion(outputs1, id_labels[0])
                loss_id2 = criterion(outputs2, id_labels[1])
                loss_id = loss_id1 + loss_id2
                loss_verif = criterion(score, vf_labels)
                loss = loss_verif + loss_id
                # loss = loss_verif
                # if opt.net_loss_model == 0:
                #     loss = loss_id + loss_verif
                # elif opt.net_loss_model == 1:
                #     loss = loss_verif
                # elif opt.net_loss_model == 2:
                #     loss = loss_id
                # else:
                #     print('opt.net_loss_model = %s    error !!!' % opt.net_loss_model)
                #     exit()

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                # statistics
                if int(version[0]) > 0 or int(version[2]) > 3:  # for the new version like 0.4.0 and 0.5.0
                    running_id_loss += loss.item()  # * opt.batchsize
                    running_verif_loss += loss_verif.item()  # * opt.batchsize
                else:  # for the old version like 0.3.0 and 0.3.1
                    running_id_loss += loss.data[0]
                    running_verif_loss += loss_verif.data[0]
                running_id_corrects += float(torch.sum(id_preds1 == id_labels[0].data))
                running_id_corrects += float(torch.sum(id_preds2 == id_labels[1].data))
                running_verif_corrects += float(torch.sum(vf_preds == vf_labels))

            datasize = dataset_sizes['train'] // opt.batchsize * opt.batchsize
            epoch_id_loss = running_id_loss / datasize
            epoch_verif_loss = running_verif_loss / datasize
            epoch_id_acc = running_id_corrects / (datasize * 2)
            epoch_verif_acc = running_verif_corrects / datasize

            print('{} Loss_id: {:.4f} Loss_verif: {:.4f}  Acc_id: {:.4f} Verif_Acc: {:.4f} '.format(
                phase, epoch_id_loss, epoch_verif_loss, epoch_id_acc, epoch_verif_acc))

            epoch_acc = (epoch_id_acc + epoch_verif_acc) / 2.0
            epoch_loss = (epoch_id_loss + epoch_verif_loss) / 2.0
            if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5 and epoch_loss < best_loss):
                best_acc = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                save_network(model, name, 'best_siamese')
                save_network(model, name, 'best_siamese_' + str(opt.net_loss_model))
                save_whole_network(model, name, 'whole_best_siamese')

            y_loss[phase].append(epoch_id_loss)
            y_err[phase].append(1.0 - epoch_id_acc)
            # deep copy the model

            if epoch % 10 == 9:
                save_network(model, name, epoch)

            draw_curve(epoch)
            last_model_wts = model.state_dict()

    time_elapsed = time.time() - since
    print('best_epoch = %s     best_loss = %s     best_acc = %s' % (best_epoch, best_loss, best_acc))
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # print('Best val Acc: {:4f}'.format(best_acc))

    # load the last model weights (the best checkpoint was already saved above)
    model.load_state_dict(last_model_wts)
    save_network(model, name, 'last_siamese')
    save_network(model, name, 'last_siamese_' + str(opt.net_loss_model))
    save_whole_network(model, name, 'whole_last_siamese')
    return model
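
The routine above sums two identification losses with a same/different verification loss, all computed by a single criterion. A compact sketch of that joint objective with illustrative shapes (the two-branch model itself is not shown here):

import torch
import torch.nn as nn

# Joint identification + verification objective, with hypothetical shapes.
criterion = nn.CrossEntropyLoss()
num_ids = 10
outputs1 = torch.randn(4, num_ids)        # ID logits from branch 1
outputs2 = torch.randn(4, num_ids)        # ID logits from branch 2
score = torch.randn(4, 2)                 # same/different verification logits
id1 = torch.randint(0, num_ids, (4,))
id2 = torch.randint(0, num_ids, (4,))
vf_labels = (id1 == id2).long()           # verification label from the ID labels

loss_id = criterion(outputs1, id1) + criterion(outputs2, id2)
loss_verif = criterion(score, vf_labels)
loss = loss_verif + loss_id               # the combined loss used above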
Code example #4
def train_with_softlabel(model,
                         criterion_soft,
                         optimizer,
                         scheduler,
                         num_epochs=25):
    since = time.time()
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                inputs, id_labels = data
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size < opt.batchsize:  # next epoch
                    continue

                id_labels_soft = get_soft_label_lsr(id_labels)

                if use_gpu:
                    inputs = inputs.cuda()
                    id_labels_soft = id_labels_soft.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                output = model(inputs)[0]
                _, id_preds = torch.max(output.detach(), 1)
                loss = criterion_soft(output, id_labels_soft)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                # statistics
                running_loss += loss.item()  # * opt.batchsize
                # running_corrects += float(torch.sum(id_preds == id_labels.detach()))
                running_corrects += float(
                    torch.sum(id_preds == id_labels_soft.argmax(1).detach()))

            datasize = dataset_sizes[phase] // opt.batchsize * opt.batchsize
            epoch_loss = running_loss / datasize
            epoch_acc = running_corrects / datasize

            print('{} Loss: {:.4f}  Acc: {:.4f} '.format(
                phase, epoch_loss, epoch_acc))
            if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5
                                        and epoch_loss < best_loss):
                best_acc = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                save_whole_network(model, name,
                                   'best' + '_' + str(opt.net_loss_model))

            if epoch % 10 == 9:
                save_whole_network(model, name, epoch)

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

    print('best_epoch = %s     best_loss = %s     best_acc = %s' %
          (best_epoch, best_loss, best_acc))
    save_whole_network(model, name, 'last' + '_' + str(opt.net_loss_model))
    return model
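
get_soft_label_lsr and criterion_soft are not defined in this excerpt. A common label-smoothing construction and a matching soft-target cross-entropy, sketched here as assumptions about their behavior:

import torch
import torch.nn.functional as F

# Assumed LSR construction: the true class gets 1 - eps and the remaining
# probability mass is spread uniformly over the other classes.
def soft_label_lsr(labels, num_classes, eps=0.1):
    soft = torch.full((labels.size(0), num_classes), eps / (num_classes - 1))
    soft.scatter_(1, labels.unsqueeze(1), 1.0 - eps)
    return soft

# Assumed soft-target cross-entropy to pair with it.
def soft_cross_entropy(logits, soft_targets):
    return -(soft_targets * F.log_softmax(logits, dim=1)).sum(1).mean()

labels = torch.tensor([2, 0, 1])
soft = soft_label_lsr(labels, num_classes=4)
print(soft_cross_entropy(torch.randn(3, 4), soft))
print(soft.argmax(1))   # recovers the hard labels, as the accuracy count above assumes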
Code example #5
def train(model, criterion_triplet, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            triplet_running_loss = 0.0
            triplet_running_corrects = 0.0
            running_margin = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                anchors, labels, pos = data
                now_batch_size, c, h, w = anchors.shape
                if now_batch_size < opt.batchsize:  # next epoch
                    continue

                pos = pos.view(-1, c, h, w)
                # copy pos labels 4times
                pos_labels = labels.repeat(4).reshape(4, now_batch_size)
                pos_labels = pos_labels.transpose(0, 1).reshape(4 *
                                                                now_batch_size)
                if use_gpu:
                    anchors = anchors.cuda()
                    pos = pos.cuda()
                    labels = labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                anchor_features = model(anchors)
                pos_features = model(pos)

                f = anchor_features
                pf = pos_features
                neg_labels = pos_labels
                # hard-neg
                # ----------------------------------
                nf_data = pf  # 128*512
                rand = np.random.permutation(4 * now_batch_size)
                nf_data = nf_data[rand, :]
                neg_labels = neg_labels[rand]
                nf_t = nf_data.transpose(0, 1)  # 512*64
                score = torch.mm(f.data, nf_t)  # cosine 16*64
                score, rank = score.sort(dim=1,
                                         descending=True)  # score high == hard
                labels_cpu = labels.cpu()
                nf_hard = torch.zeros(f.shape).cuda()
                for k in range(now_batch_size):
                    hard = rank[k, :]
                    for kk in hard:
                        now_label = neg_labels[kk]
                        anchor_label = labels_cpu[k]
                        if now_label != anchor_label:
                            nf_hard[k, :] = nf_data[kk, :]
                            break

                # hard-pos
                # ----------------------------------
                pf_hard = torch.zeros(f.shape).cuda()  # 16*512
                for k in range(now_batch_size):
                    pf_data = pf[4 * k:4 * k + 4, :]
                    pf_t = pf_data.transpose(0, 1)  # 512*4
                    ff = f.data[k, :].reshape(1, -1)  # 1*512
                    score = torch.mm(ff, pf_t)  # cosine
                    score, rank = score.sort(
                        dim=1, descending=False)  # score low == hard
                    pf_hard[k, :] = pf_data[rank[0][0], :]

                # loss
                # ---------------------------------
                pscore = torch.sum(f * pf_hard, dim=1)
                nscore = torch.sum(f * nf_hard, dim=1)
                loss_triplet = criterion_triplet(f, pf_hard, nf_hard)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss_triplet.backward()
                    optimizer.step()
                # statistics
                triplet_running_loss += loss_triplet.item()  # * opt.batchsize
                triplet_running_corrects += float(
                    torch.sum(pscore > nscore + 0.5))
                running_margin += float(torch.sum(pscore - nscore))

            # use opt.batchsize: a short final batch is skipped above, so
            # now_batch_size may hold that batch's smaller, stale value
            datasize = dataset_sizes[phase] // opt.batchsize * opt.batchsize
            triplet_epoch_loss = triplet_running_loss / datasize
            triplet_epoch_acc = triplet_running_corrects / datasize
            epoch_margin = running_margin / datasize
            epoch_acc = triplet_epoch_acc

            print(
                '{} triplet_epoch_loss: {:.4f} triplet_epoch_acc: {:.4f} MeanMargin: {:.4f}'
                .format(phase, triplet_epoch_loss, triplet_epoch_acc,
                        epoch_margin))

            if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5
                                        and triplet_epoch_loss < best_loss):
                best_acc = epoch_acc
                best_loss = triplet_epoch_loss
                best_epoch = epoch
                save_whole_network(model, name,
                                   'best' + '_' + str(opt.net_loss_model))

            if epoch % 10 == 9:
                save_whole_network(model, name, epoch)

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

    print('best_epoch = %s     best_loss = %s     best_acc = %s' %
          (best_epoch, best_loss, best_acc))
    save_whole_network(model, name, 'last' + '_' + str(opt.net_loss_model))
    return model
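
The hard-negative mining above walks Python loops over a ranked similarity matrix. A vectorized sketch of the same hardest-negative selection feeding torch.nn.TripletMarginLoss (sizes are illustrative, and the hard-positive stand-in is random so the sketch stays self-contained):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Pick, per anchor, the highest-scoring candidate with a different label.
f = F.normalize(torch.randn(16, 512), dim=1)             # anchor features
cand = F.normalize(torch.randn(64, 512), dim=1)          # candidate features
labels = torch.randint(0, 8, (16,))
cand_labels = torch.randint(0, 8, (64,))

score = f @ cand.t()                                     # cosine similarity, 16 x 64
same = labels.unsqueeze(1) == cand_labels.unsqueeze(0)   # same-ID mask
hard_neg = cand[score.masked_fill(same, float('-inf')).argmax(dim=1)]

# In the loop above the hard positive comes from the 4-per-anchor block of
# pos_features; a random stand-in keeps this sketch runnable on its own.
hard_pos = F.normalize(torch.randn(16, 512), dim=1)

criterion_triplet = nn.TripletMarginLoss(margin=0.3)
print(criterion_triplet(f, hard_pos, hard_neg))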
Code example #6
def train(model,
          criterion_identify,
          criterion_contrastive_same,
          criterion_contrastive_diff,
          criterion_orthogonal,
          optimizer,
          scheduler,
          num_epochs=25):
    since = time.time()
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1
    cnt = 0

    # for ResNet-50
    # Hyper-parameters about Two-level Classification Label Assignment Strategy
    w_main_c = 1.0
    w_main_mix_c = 0.0
    # Weights of Two-level Classification Loss Functions
    r_id_classify = 0.5
    r_id_mix = 0.3
    # Weights of the Structural Consistencies
    r_d = 0.02
    r_t = 0.01
    r_o = 0.05

    print(
        'r_id_classify = %.3f  r_id_mix = %.3f  r_d = %.3f  r_t = %.3f  r_o = %.3f'
        % (r_id_classify, r_id_mix, r_d, r_t, r_o))
    print('w_main_c = %.3f    w_main_mix_c = %.3f' % (w_main_c, w_main_mix_c))

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            scheduler.step()
            model.train(True)  # Set model to training mode

            running_loss = 0.0
            running_corrects = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                inputs1_0, inputs1_1, inputs2_0, inputs2_1, inputs3_0, inputs3_1, \
                label1_0, label1_1, label2_0, label2_1, label3_0, label3_1 = data
                inputs = torch.cat((inputs1_0, inputs1_1, inputs2_0, inputs2_1,
                                    inputs3_0, inputs3_1), 0)
                id_labels = torch.cat((label1_0, label1_1, label2_0, label2_1,
                                       label3_0, label3_1), 0)
                # Two-Level Labels
                id_labels_soft = get_soft_label_6domain(id_labels %
                                                        opt.class_base,
                                                        w_main=w_main_c)
                id_labels_mix_soft = get_soft_label_6domain(
                    id_labels % opt.class_base, w_main=w_main_mix_c)
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size // 6 < opt.batchsize:  # next epoch
                    print('continue')
                    continue

                if use_gpu:
                    inputs = inputs.cuda()
                    id_labels_soft = id_labels_soft.cuda()
                    id_labels_mix_soft = id_labels_mix_soft.cuda()
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                outputs, features = model(inputs)
                mask = torch.FloatTensor(outputs.shape[0]).zero_().cuda()
                # integer division gives the domain index; plain / became
                # true division on integer tensors in PyTorch >= 1.5
                for d in range(opt.domain_num):
                    mask[now_batch_size * d: now_batch_size * (d + 1)] \
                        = (id_labels // opt.class_base >= d) * (id_labels // opt.class_base < (d + 1))
                _, id_preds = torch.max(outputs.detach(), 1)
                id_labels_soft_all = id_labels_soft
                id_labels_mix_soft_all = id_labels_mix_soft
                for d in range(opt.domain_num - 1):
                    id_labels_soft_all = torch.cat(
                        (id_labels_soft_all, id_labels_soft), 0)
                    id_labels_mix_soft_all = torch.cat(
                        (id_labels_mix_soft_all, id_labels_mix_soft), 0)
                # Intra-domain Category Classification Loss
                loss_id_classify = criterion_identify(outputs,
                                                      id_labels_soft_all, mask)
                # Domain Classification Loss
                loss_id_mix_classify = criterion_identify(
                    outputs, id_labels_mix_soft_all, (1 - mask))
                loss_id = r_id_classify * loss_id_classify + r_id_mix * loss_id_mix_classify
                features = features[:now_batch_size]
                part_len = features.shape[0] // 6

                feature1_0 = features[part_len * 0:part_len * 1]
                feature1_1 = features[part_len * 1:part_len * 2]
                feature2_0 = features[part_len * 2:part_len * 3]
                feature2_1 = features[part_len * 3:part_len * 4]
                feature3_0 = features[part_len * 4:part_len * 5]
                feature3_1 = features[part_len * 5:part_len * 6]

                loss_con_dist = 0
                loss_con_topol = 0
                loss_con_orth = 0

                # Inter domain
                # Inter-Domain Distance Consistency Loss
                loss_con_dist += criterion_contrastive_same(
                    (feature1_0 - feature1_1), (feature2_0 - feature2_1))
                loss_con_dist += criterion_contrastive_same(
                    (feature1_0 - feature1_1), (feature3_0 - feature3_1))
                loss_con_dist += criterion_contrastive_same(
                    (feature2_0 - feature2_1), (feature3_0 - feature3_1))

                # Intra domain
                # Cross-domain Topology Consistency Loss
                loss_con_topol += criterion_contrastive_same(
                    (feature1_0 - feature2_0), (feature1_1 - feature2_1))
                loss_con_topol += criterion_contrastive_same(
                    (feature1_0 - feature3_0), (feature1_1 - feature3_1))
                loss_con_topol += criterion_contrastive_same(
                    (feature2_0 - feature3_0), (feature2_1 - feature3_1))

                # Cross-domain Orthogonality Loss
                loss_con_orth += criterion_orthogonal(feature1_0, feature1_1)
                loss_con_orth += criterion_orthogonal(feature2_0, feature2_1)
                loss_con_orth += criterion_orthogonal(feature3_0, feature3_1)

                loss_con = r_d * loss_con_dist + r_t * loss_con_topol + r_o * loss_con_orth

                # calculate the total loss
                loss = loss_id + loss_con
                if cnt % 200 == 0:
                    print(
                        'cnt = %5d   loss = %.4f  loss_id = %.4f  loss_con = %.4f'
                        % (cnt, loss.cpu().detach().numpy(),
                           loss_id.cpu().detach().numpy(),
                           loss_con.cpu().detach().numpy()))
                    print(
                        'loss_con_dist  = %.4f  loss_con_topol  = %.4f  loss_con_orth  = %.4f'
                        % (loss_con_dist.cpu().detach().numpy(),
                           loss_con_topol.cpu().detach().numpy(),
                           loss_con_orth.cpu().detach().numpy()))
                    print(
                        'loss_id_classify = %.4f  loss_id_mix_classify = %.4f'
                        % (loss_id_classify.cpu().detach().numpy(),
                           loss_id_mix_classify.cpu().detach().numpy()))

                cnt += 1
                loss.backward()
                optimizer.step()
                # statistics
                running_loss += loss.item()  # * opt.batchsize
                # count a prediction as correct only where the intra-domain
                # mask is active (the original compared against masked labels,
                # which also matched spurious class-0 predictions)
                running_corrects += float(
                    torch.sum((id_preds == id_labels_soft_all.argmax(1).detach())
                              * mask.long()))
            datasize = dataset_sizes[phase] // opt.batchsize * opt.batchsize
            epoch_loss = running_loss / datasize
            epoch_acc = running_corrects / (datasize * 6)

            print('{} Loss: {:.4f}  Acc_id: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5
                                        and epoch_loss < best_loss):
                best_acc = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                save_whole_network(model, name,
                                   'best' + '_' + str(opt.net_loss_model))

            save_whole_network(model, name, epoch)
            time_elapsed = time.time() - since
            print('Training complete in {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))

    print('best_epoch = %s     best_loss = %s     best_acc = %s' %
          (best_epoch, best_loss, best_acc))
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    save_whole_network(model, name, 'last' + '_' + str(opt.net_loss_model))
    return model
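
criterion_orthogonal is likewise not shown. A common orthogonality penalty, assumed here, drives paired features toward a zero inner product by penalizing the squared row-wise dot products:

import torch
import torch.nn.functional as F

# Assumed orthogonality criterion over row-paired feature matrices.
def orthogonal_loss(f1, f2):
    f1 = F.normalize(f1, dim=1)
    f2 = F.normalize(f2, dim=1)
    return (f1 * f2).sum(dim=1).pow(2).mean()

a, b = torch.randn(8, 256), torch.randn(8, 256)
print(orthogonal_loss(a, b))   # near zero for unrelated random features
print(orthogonal_loss(a, a))   # exactly 1 when the features coincide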
Code example #7
def train(model,
          criterion_contrastive,
          criterion_verify,
          optimizer,
          scheduler,
          num_epochs=25):
    since = time.time()
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                (inputs1, inputs2), siamese_labels = data
                now_batch_size, c, h, w = inputs1.shape
                if now_batch_size < opt.batchsize:  # next epoch
                    continue

                if use_gpu:
                    inputs1 = inputs1.cuda()
                    inputs2 = inputs2.cuda()
                    siamese_labels = siamese_labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                feature1, feature2, output = model(inputs1, inputs2)
                _, verify_preds = torch.max(output.detach(), 1)

                # representation learning
                loss_verify = criterion_verify(output, siamese_labels)
                # metric learning
                loss_contrastive = criterion_contrastive(
                    feature1, feature2, siamese_labels)

                loss = loss_verify + 0 * loss_contrastive  # contrastive term disabled by zero weight

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                # statistics
                running_loss += loss.item()  # * opt.batchsize
                running_corrects += float(
                    torch.sum(verify_preds == siamese_labels.detach()))

            datasize = dataset_sizes[phase] // opt.batchsize * opt.batchsize
            epoch_loss = running_loss / datasize
            epoch_acc = running_corrects / datasize

            print('{} Loss: {:.4f}  Acc: {:.4f} '.format(
                phase, epoch_loss, epoch_acc))
            print('loss_verify: {:.4f}  loss_contrastive: {:.4f} '.format(
                loss_verify, loss_contrastive))
            if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5
                                        and epoch_loss < best_loss):
                best_acc = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                save_whole_network(model, name,
                                   'best' + '_' + str(opt.net_loss_model))

            if epoch % 10 == 9:
                save_whole_network(model, name, epoch)

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

    print('best_epoch = %s     best_loss = %s     best_acc = %s' %
          (best_epoch, best_loss, best_acc))
    save_whole_network(model, name, 'last' + '_' + str(opt.net_loss_model))
    return model
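
The model above returns two feature vectors plus a 2-way verification output. A hypothetical verification head in the style common for siamese re-ID, classifying the squared element-wise feature difference as same or different (an assumption, not the project's actual module):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical verification head: square the feature difference,
# then classify the result as same (1) or different (0).
class VerifyHead(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.fc = nn.Linear(dim, 2)

    def forward(self, f1, f2):
        return self.fc((f1 - f2).pow(2))   # 2-way same/different logits

head = VerifyHead()
f1, f2 = torch.randn(4, 512), torch.randn(4, 512)
output = head(f1, f2)
siamese_labels = torch.randint(0, 2, (4,))
print(F.cross_entropy(output, siamese_labels))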