Example #1
0
def train_gcn(train_loader, model_siamese, loss_siamese_fn, optimizer_siamese, scheduler_siamese,
              model_gcn, loss_gcn_fn, optimizer_gcn, scheduler_gcn, num_epochs=25):
    """Train ``model_gcn`` on the outputs of a frozen ``model_siamese``.

    The siamese model is switched to eval mode and is only ever called under
    ``torch.no_grad()``; its outputs are unpacked into ``model_gcn``, which
    must return an ``(outputs, target)`` pair. Only ``optimizer_gcn`` is
    stepped. The GCN is written out after every epoch through the
    module-level ``save_network`` / ``save_whole_network`` helpers, and once
    more after the last epoch under the labels ``'best_gcn'`` /
    ``'whole_best_gcn'`` (that final save is simply the last state; no
    metric-based selection happens here).

    Args:
        train_loader: Iterable yielding ``(data, target)`` batches. ``data``
            may be a single object or a tuple/list; an empty ``target`` is
            treated as "no target".
        model_siamese: Frozen model, called as ``model_siamese(*data, target)``.
        loss_siamese_fn: Unused; kept for interface compatibility.
        optimizer_siamese: Unused; kept for interface compatibility.
        scheduler_siamese: Stepped once at the start of every epoch.
        model_gcn: Model being trained.
        loss_gcn_fn: Called with the GCN outputs followed by the target (when
            there is one). May return the loss directly or a tuple/list whose
            first element is the loss.
        optimizer_gcn: Optimizer for ``model_gcn``.
        scheduler_gcn: Stepped once at the start of every epoch.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``model_gcn`` after training.
    """
    since = time.time()
    model_gcn.train(True)
    model_siamese.eval()
    for epoch in range(num_epochs):
        # NOTE(review): both schedulers are stepped before any optimizer step
        # of the epoch. PyTorch >= 1.1.0 documents scheduler.step() as coming
        # after optimizer.step(); confirm against the targeted version before
        # reordering, since that would shift the learning-rate schedule.
        scheduler_siamese.step()
        scheduler_gcn.step()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for batch_idx, (data, target) in enumerate(train_loader):
            target = target if len(target) > 0 else None
            if not isinstance(data, (tuple, list)):
                data = (data,)
            if use_gpu:
                data = tuple(d.cuda() for d in data)
                if target is not None:
                    target = target.cuda()

            optimizer_gcn.zero_grad()

            # The siamese network is a fixed feature extractor here: no
            # gradients are recorded for it.
            with torch.no_grad():
                outputs = model_siamese(*data, target)

            outputs, target = model_gcn(*outputs)  # for SGGNN

            if not isinstance(outputs, (tuple, list)):
                outputs = (outputs,)

            loss_inputs = tuple(outputs)
            if target is not None:
                loss_inputs += (target,)

            # Only move tensors to the GPU when GPU training was requested;
            # doing it unconditionally breaks CPU-only runs.
            if use_gpu:
                loss_inputs = tuple(d.cuda() for d in loss_inputs)

            loss_outputs = loss_gcn_fn(*loss_inputs)
            loss = loss_outputs[0] if isinstance(loss_outputs, (tuple, list)) else loss_outputs
            loss.backward()
            optimizer_gcn.step()
            if batch_idx % 5 == 0:
                print('epoch = %2d  batch_idx = %4d  loss = %.5f' % (epoch, batch_idx, loss.item()))
        save_network(model_gcn, name, 'gcn' + str(epoch))
        save_whole_network(model_gcn, name, 'whole_gcn' + str(epoch))
    time_elapsed = time.time() - since
    print('time = %f' % (time_elapsed))
    save_network(model_gcn, name, 'best_gcn')
    save_whole_network(model_gcn, name, 'whole_best_gcn')
    return model_gcn
Example #2
0
def train_model(train_loader, model, loss_fn, optimizer, num_epochs=25):
    """Run a plain supervised training loop over ``train_loader``.

    Each batch is a ``(data, target)`` pair. ``data`` is wrapped into a tuple
    when it is not already a tuple/list and unpacked into ``model``; an empty
    ``target`` is treated as absent. ``loss_fn`` receives the model outputs
    followed by the target (if any) and may return either the loss itself or
    a tuple/list whose first element is the loss.

    The model is saved through the module-level ``save_network`` /
    ``save_whole_network`` helpers every fifth epoch, and once more under the
    label ``'best'`` after the last epoch.

    Returns:
        ``model`` after training.
    """
    global cnt  # declared as in the original; never assigned in this function
    started_at = time.time()
    model.train()
    batch_losses = []
    loss_sum = 0
    sequence_types = (tuple, list)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for batch_idx, (data, target) in enumerate(train_loader):
            # An empty target means this batch carries no labels.
            if not len(target) > 0:
                target = None
            if type(data) not in sequence_types:
                data = (data,)
            if use_gpu:
                data = tuple(item.cuda() for item in data)
                if target is not None:
                    target = target.cuda()

            optimizer.zero_grad()
            outputs = model(*data)  # for contrastive loss
            if type(outputs) not in sequence_types:
                outputs = (outputs,)

            # Loss arguments: every model output, then the target if present.
            loss_args = outputs
            if target is not None:
                loss_args += (target,)

            loss_result = loss_fn(*loss_args)
            if type(loss_result) in sequence_types:
                loss = loss_result[0]
            else:
                loss = loss_result

            loss_value = loss.item()
            batch_losses.append(loss_value)
            loss_sum += loss_value
            loss.backward()
            optimizer.step()
            print('batch_idx = %4d  loss = %f' % (batch_idx, loss))

        if (epoch + 1) % 5 == 0:
            # NOTE(review): the two checkpoints use different labels
            # (epoch + 1 vs epoch); kept as-is, confirm which one is intended.
            save_network(model, str(epoch + 1))
            save_whole_network(model, str(epoch))

    elapsed = time.time() - started_at
    print('time = %f' % (elapsed))
    save_network(model, 'best')
    save_whole_network(model, 'best')
    return model
Example #3
0
def train_model_siamese(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train a siamese model with identification and verification losses.

    Batches come from the module-level ``dataloaders['train']`` as
    ``(inputs, vf_labels, id_labels)``, where ``inputs`` and ``id_labels``
    are indexed as pairs. ``model(inputs[0], inputs[1])`` must return six
    values ``(outputs1, f1, outputs2, f2, feature, score)``. ``criterion`` is
    applied to both identity outputs and to the verification ``score``; the
    sum of the three terms is back-propagated. Batches smaller than
    ``opt.batchsize`` are skipped.

    After each epoch the averaged identity/verification loss and accuracy
    are printed, the identity loss and error are appended to
    ``y_loss['train']`` / ``y_err['train']``, ``draw_curve(epoch)`` is called,
    and the model is saved when the mean of both accuracies improves (ties
    within 1e-5 are broken by the lower mean loss). A numbered checkpoint is
    saved every tenth epoch, and the final model is saved under the
    ``'last_siamese'`` labels.

    Args:
        model: Siamese network to train.
        criterion: Loss used for both identification and verification.
        optimizer: Optimizer stepped once per processed batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Number of epochs; ``0`` just saves the current model.

    Returns:
        ``model`` after training.
    """
    since = time.time()

    # Bound before the loop so that the final load/save below also works when
    # num_epochs == 0 (it used to raise UnboundLocalError).
    last_model_wts = model.state_dict()
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1
    phase = 'train'  # only a training phase is run
    # PyTorch >= 0.4.0 reads scalar losses with .item(); 0.3.x uses .data[0].
    has_item = int(version[0]) > 0 or int(version[2]) > 3

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps; PyTorch >= 1.1.0 documents the opposite order. Confirm the
        # targeted version before reordering.
        scheduler.step()
        model.train(True)  # Set model to training mode

        running_id_loss = 0.0
        running_verif_loss = 0.0
        running_id_corrects = 0.0
        running_verif_corrects = 0.0
        for data in dataloaders[phase]:
            inputs, vf_labels, id_labels = data
            now_batch_size, _, _, _ = inputs[0].shape
            if now_batch_size < opt.batchsize:
                # Short batch: skipped so that every processed batch has
                # exactly opt.batchsize samples (datasize below relies on it).
                continue

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)
            if not isinstance(id_labels, (tuple, list)):
                id_labels = (id_labels,)
            if use_gpu:
                inputs = tuple(d.cuda() for d in inputs)
                id_labels = tuple(d.cuda() for d in id_labels)
                if vf_labels is not None:
                    vf_labels = vf_labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs1, f1, outputs2, f2, feature, score = model(inputs[0], inputs[1])
            _, id_preds1 = torch.max(outputs1.data, 1)
            _, id_preds2 = torch.max(outputs2.data, 1)
            _, vf_preds = torch.max(score.data, 1)
            loss_id = criterion(outputs1, id_labels[0]) + criterion(outputs2, id_labels[1])
            loss_verif = criterion(score, vf_labels)
            loss = loss_verif + loss_id

            loss.backward()
            optimizer.step()

            # statistics -- the identity counter tracks loss_id itself, not
            # the combined loss, so that "Loss_id" reports what it names and
            # the verification loss is not counted twice in epoch_loss.
            if has_item:
                running_id_loss += loss_id.item()
                running_verif_loss += loss_verif.item()
            else:
                running_id_loss += loss_id.data[0]
                running_verif_loss += loss_verif.data[0]
            running_id_corrects += float(torch.sum(id_preds1 == id_labels[0].data))
            running_id_corrects += float(torch.sum(id_preds2 == id_labels[1].data))
            running_verif_corrects += float(torch.sum(vf_preds == vf_labels))

        # Number of samples actually processed (short batches were skipped).
        datasize = dataset_sizes['train'] // opt.batchsize * opt.batchsize
        epoch_id_loss = running_id_loss / datasize
        epoch_verif_loss = running_verif_loss / datasize
        # Two identity predictions are scored per sample pair.
        epoch_id_acc = running_id_corrects / (datasize * 2)
        epoch_verif_acc = running_verif_corrects / datasize

        print('{} Loss_id: {:.4f} Loss_verif: {:.4f}  Acc_id: {:.4f} Verif_Acc: {:.4f} '.format(
            phase, epoch_id_loss, epoch_verif_loss, epoch_id_acc, epoch_verif_acc))

        epoch_acc = (epoch_id_acc + epoch_verif_acc) / 2.0
        epoch_loss = (epoch_id_loss + epoch_verif_loss) / 2.0
        if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5 and epoch_loss < best_loss):
            best_acc = epoch_acc
            best_loss = epoch_loss
            best_epoch = epoch
            save_network(model, name, 'best_siamese')
            save_network(model, name, 'best_siamese_' + str(opt.net_loss_model))
            save_whole_network(model, name, 'whole_best_siamese')

        y_loss[phase].append(epoch_id_loss)
        y_err[phase].append(1.0 - epoch_id_acc)

        if epoch % 10 == 9:
            save_network(model, name, epoch)

        draw_curve(epoch)
        last_model_wts = model.state_dict()

    time_elapsed = time.time() - since
    print('best_epoch = %s     best_loss = %s     best_acc = %s' % (best_epoch, best_loss, best_acc))
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # Reload the weights captured after the final epoch (not the best epoch)
    # and save them as the "last" checkpoints.
    model.load_state_dict(last_model_wts)
    save_network(model, name, 'last_siamese')
    save_network(model, name, 'last_siamese_' + str(opt.net_loss_model))
    save_whole_network(model, name, 'whole_last_siamese')
    return model
Example #4
0
def train_model_triplet(model, model_verif, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` on (anchor, positive, negative) triplets.

    Batches come from the module-level ``dataloaders['train']`` as
    ``(inputs, labels, pos, neg)``. ``model`` must return ``(outputs, f)``;
    it is run on the anchor, the positive and the negative. ``model_verif``
    scores the squared feature differences ``(pf - f)**2`` and
    ``(nf - f)**2``, with positives labelled 0 and negatives labelled 1.
    ``opt.net_loss_model`` selects the optimised loss: 0 = identity +
    verification, 1 = verification only, 2 = identity only. Batches smaller
    than ``opt.batchsize`` are skipped.

    After each epoch the averaged losses and accuracies are printed, the loss
    and error are appended to ``y_loss['train']`` / ``y_err['train']``,
    ``draw_curve(epoch)`` is called, and the model is saved as ``'best'``
    when the mean of identity and verification accuracy improves (ties within
    1e-5 are broken by the lower mean loss). A numbered checkpoint is saved
    every tenth epoch, and the final model is saved as ``'last'``.

    Args:
        model: Feature/identity network to train.
        model_verif: Verification head applied to squared feature differences.
        criterion: Loss used for both identification and verification.
        optimizer: Optimizer stepped once per processed batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Number of epochs; ``0`` just saves the current model.

    Returns:
        ``model`` after training.
    """
    since = time.time()

    # Bound before the loop so that the final load/save below also works when
    # num_epochs == 0 (it used to raise UnboundLocalError).
    last_model_wts = model.state_dict()
    best_acc = 0.0
    best_loss = 10000.0
    best_epoch = -1
    phase = 'train'  # only a training phase is run
    # PyTorch >= 0.4.0 reads scalar losses with .item(); 0.3.x uses .data[0].
    has_item = int(version[0]) > 0 or int(version[2]) > 3

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps; PyTorch >= 1.1.0 documents the opposite order. Confirm the
        # targeted version before reordering.
        scheduler.step()
        model.train(True)  # Set model to training mode

        running_loss = 0.0
        running_verif_loss = 0.0
        running_corrects = 0.0
        running_verif_corrects = 0.0
        for data in dataloaders[phase]:
            inputs, labels, pos, neg = data
            now_batch_size, _, _, _ = inputs.shape

            if now_batch_size < opt.batchsize:
                # Short batch: skipped so that every processed batch has
                # exactly opt.batchsize samples (datasize below relies on it).
                continue

            # Verification targets: 0 for the positive pair, 1 for the
            # negative pair.
            labels_0 = torch.zeros(now_batch_size).long()
            labels_1 = torch.ones(now_batch_size).long()
            # Everything follows use_gpu; the verification targets used to be
            # moved to CUDA unconditionally, which broke CPU-only runs.
            if use_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()
                pos, neg = pos.cuda(), neg.cuda()
                labels_0, labels_1 = labels_0.cuda(), labels_1.cuda()
            inputs, labels = Variable(inputs), Variable(labels)
            pos, neg = Variable(pos), Variable(neg)
            labels_0, labels_1 = Variable(labels_0), Variable(labels_1)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs, f = model(inputs)
            _, pf = model(pos)
            _, nf = model(neg)
            pscore = model_verif((pf - f).pow(2))
            nscore = model_verif((nf - f).pow(2))

            _, preds = torch.max(outputs.data, 1)
            _, p_preds = torch.max(pscore.data, 1)
            _, n_preds = torch.max(nscore.data, 1)
            loss_id = criterion(outputs, labels)
            loss_verif = (criterion(pscore, labels_0) + criterion(nscore, labels_1)) * 0.5 * opt.alpha
            if opt.net_loss_model == 0:
                loss = loss_id + loss_verif
            elif opt.net_loss_model == 1:
                loss = loss_verif
            elif opt.net_loss_model == 2:
                loss = loss_id
            else:
                print('opt.net_loss_model = %s    error !!!' % opt.net_loss_model)
                exit()

            loss.backward()
            optimizer.step()

            # statistics
            if has_item:
                running_loss += loss.item()
                running_verif_loss += loss_verif.item()
            else:
                running_loss += loss.data[0]
                running_verif_loss += loss_verif.data[0]
            running_corrects += float(torch.sum(preds == labels.data))
            running_verif_corrects += float(torch.sum(p_preds == 0)) + float(torch.sum(n_preds == 1))

        # Number of samples actually processed (short batches were skipped).
        datasize = dataset_sizes['train'] // opt.batchsize * opt.batchsize
        epoch_loss = running_loss / datasize
        epoch_verif_loss = running_verif_loss / datasize
        epoch_acc = running_corrects / datasize
        # Two verification predictions (positive and negative) per sample.
        epoch_verif_acc = running_verif_corrects / (2 * datasize)

        print('{} Loss: {:.4f} Loss_verif: {:.4f}  Acc: {:.4f} Verif_Acc: {:.4f} '.format(
            phase, epoch_loss, epoch_verif_loss, epoch_acc, epoch_verif_acc))

        y_loss[phase].append(epoch_loss)
        y_err[phase].append(1.0 - epoch_acc)

        # Model selection uses the mean of identity and verification metrics.
        epoch_acc = (epoch_acc + epoch_verif_acc) / 2.0
        epoch_loss = (epoch_loss + epoch_verif_loss) / 2.0
        if epoch_acc > best_acc or (np.fabs(epoch_acc - best_acc) < 1e-5 and epoch_loss < best_loss):
            best_acc = epoch_acc
            best_loss = epoch_loss
            best_epoch = epoch
            save_network(model, name, 'best')

        if epoch % 10 == 9:
            save_network(model, name, epoch)
        draw_curve(epoch)
        last_model_wts = model.state_dict()

        print()

    time_elapsed = time.time() - since
    print('best_epoch = %s     best_loss = %s     best_acc = %s' % (best_epoch, best_loss, best_acc))
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # Reload the weights captured after the final epoch (not the best epoch)
    # and save them as the "last" checkpoint.
    model.load_state_dict(last_model_wts)
    save_network(model, name, 'last')
    return model
def train(model,
          criterion_identify,
          criterion_reconstruct,
          criterion_contrastive,
          optimizer,
          scheduler,
          num_epochs=25):
    """Train ``model`` with identity, reconstruction and contrastive losses.

    Batches come from the module-level ``dataloaders['train']`` as
    ``(inputs1, inputs2, id_labels1, id_labels2)``. For every batch, four
    soft-label targets are built with ``get_soft_label_6domain`` (a "content"
    and a "sketch" weighting for each of the two label sets). The model is
    called as ``model(inputs1, inputs2)`` and must return twelve values: four
    identity outputs, four reconstructed images and four coder features.

    Total loss = ``r_id * loss_id + r_rec * loss_rec + r_s * loss_s +
    r_c * loss_c``, where ``loss_c, loss_s`` come from
    ``criterion_contrastive``. Batches smaller than ``opt.batchsize`` are
    skipped. A progress line is printed every 200 processed batches.

    After each epoch the averaged loss and accuracy are printed, a ``best_*``
    checkpoint is saved when accuracy improves (ties within 1e-5 are broken
    by the lower loss), and a numbered checkpoint is saved unconditionally.
    A ``last_*`` checkpoint is saved once training ends.
    """
    start_time = time.time()
    top_acc = 0.0
    top_loss = 10000.0
    top_epoch = -1
    step = 0
    phase = 'train'  # the only phase this loop runs

    # Loss / soft-label weights for cross-dataset fine-tuning
    # (best for Duke, evaluation for Market).
    # Alternative setting (best for Market, evaluation for Duke, CUHK03 and
    # MSMT17): r_id=0.3 r_rec=0.3 r_con=1.0 r_s=0.5 r_c=0.2
    #          w_main_c=0.7 w_sketch_c=0 w_main_s=0.9 w_content_s=0
    # Ablations: r_id = 0 disables the Id loss; r_c = 0 / r_s = 0 disable the
    # contrastive losses.
    r_id, r_rec, r_con = 0.3, 0.2, 1.0  # r_con is only printed below
    r_s, r_c = 0.6, 0.2
    w_main_c, w_sketch_c = 0.7, 0
    w_main_s, w_content_s = 0.95, 0
    w_content_c = 1 - w_main_c - 1e-5
    w_sketch_s = 1 - w_main_s - 1e-5

    print('r_id = %.3f   r_rec = %.3f   r_con = %.3f   r_s = %.3f   r_c = %.3f'
          % (r_id, r_rec, r_con, r_s, r_c))
    print('w_main_c = %.3f   w_content_c = %.3f   w_sketch_c = %.3f'
          % (w_main_c, w_content_c, w_sketch_c))
    print('w_main_s = %.3f   w_content_s = %.3f   w_sketch_s = %.3f'
          % (w_main_s, w_content_s, w_sketch_s))

    def content_soft_label(labels):
        # Soft label built with the "content" weighting.
        return get_soft_label_6domain(labels,
                                      w_main=w_main_c,
                                      w_content=w_content_c,
                                      w_sketch=w_sketch_c,
                                      domain_num=opt.domain_num)

    def sketch_soft_label(labels):
        # Soft label built with the "sketch" weighting.
        return get_soft_label_6domain(labels,
                                      w_main=w_main_s,
                                      w_content=w_content_s,
                                      w_sketch=w_sketch_s,
                                      domain_num=opt.domain_num)

    def as_numpy(value):
        # Host-side numpy copy used for %-formatting in the progress print.
        return value.cpu().detach().numpy()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        scheduler.step()
        model.train(True)  # Set model to training mode

        loss_total = 0.0
        hits = 0.0
        for inputs1, inputs2, id_labels1, id_labels2 in dataloaders[phase]:
            # Soft labels are built before the batch-size check, as in the
            # original ordering.
            soft_c1 = content_soft_label(id_labels1)
            soft_c2 = content_soft_label(id_labels2)
            soft_s1 = sketch_soft_label(id_labels1)
            soft_s2 = sketch_soft_label(id_labels2)

            batch_len, _, _, _ = inputs1.shape
            if batch_len < opt.batchsize:
                # Short batch: not trained on.
                continue

            if use_gpu:
                inputs1, inputs2 = inputs1.cuda(), inputs2.cuda()
                soft_c1, soft_c2 = soft_c1.cuda(), soft_c2.cuda()
                soft_s1, soft_s2 = soft_s1.cuda(), soft_s2.cuda()

            optimizer.zero_grad()

            (logits_c1, logits_c2, logits_s1, logits_s2,
             rec11, rec12, rec21, rec22,
             feat_c1, feat_c2, feat_s1, feat_s2) = model(inputs1, inputs2)

            # Each identity output paired with the soft label it is scored
            # against.
            heads = ((logits_c1, soft_c1), (logits_c2, soft_c2),
                     (logits_s1, soft_s1), (logits_s2, soft_s2))
            predictions = [torch.max(logits.detach(), 1)[1] for logits, _ in heads]

            loss_id = 0
            for logits, soft in heads:
                loss_id += criterion_identify(logits, soft)

            # Both reconstructions *1 target inputs1, both *2 target inputs2.
            loss_rec = 0
            for rec_img, reference in ((rec11, inputs1), (rec12, inputs1),
                                       (rec21, inputs2), (rec22, inputs2)):
                loss_rec += criterion_reconstruct(rec_img, reference)

            loss_c, loss_s = criterion_contrastive(feat_c1, feat_c2, feat_s1, feat_s2)
            loss_con = r_s * loss_s + r_c * loss_c
            # calculate the total loss
            loss = r_id * loss_id + r_rec * loss_rec + loss_con

            if step % 200 == 0:
                print('cnt = %5d   loss   = %.4f  loss_id = %.4f  loss_rec = %.4f  loss_con = %.4f'
                      % (step, as_numpy(loss), as_numpy(loss_id),
                         as_numpy(loss_rec), as_numpy(loss_con)))
                print('loss_c = %.4f  loss_s  = %.4f'
                      % (as_numpy(loss_c), as_numpy(loss_s)))
            step += 1

            loss.backward()
            optimizer.step()

            # statistics: a prediction counts as correct when it matches the
            # arg-max class of the corresponding soft label.
            loss_total += loss.item()
            for predicted, (_, soft) in zip(predictions, heads):
                hits += float(torch.sum(predicted == soft.argmax(1).detach()))

        datasize = dataset_sizes[phase] // opt.batchsize * opt.batchsize
        epoch_loss = loss_total / datasize
        # Four identity predictions are scored per sample.
        epoch_acc = hits / (datasize * 4)

        print('{} Loss: {:.4f}  Acc: {:.4f} '.format(phase, epoch_loss, epoch_acc))
        if epoch_acc > top_acc or (np.fabs(epoch_acc - top_acc) < 1e-5
                                   and epoch_loss < top_loss):
            top_acc = epoch_acc
            top_loss = epoch_loss
            top_epoch = epoch
            save_network(model, name, 'best' + '_' + str(opt.net_loss_model))

        save_network(model, name, epoch)

    elapsed = time.time() - start_time
    print('best_epoch = %s     best_loss = %s     best_acc = %s'
          % (top_epoch, top_loss, top_acc))
    print('Training complete in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))

    save_network(model, name, 'last' + '_' + str(opt.net_loss_model))