예제 #1
0
def main() -> None:
    """Train a model with Circle Loss, then plot its validation embeddings.

    Runs 20 epochs of SGD over the training loader, then feeds every
    validation batch through the trained model, gathers the outputs and
    labels into two NumPy arrays and hands them to ``plot_features``.
    """
    # Local import: only the evaluation section below needs ``torch`` itself.
    import torch

    model = Model()
    optimizer = SGD(model.parameters(),
                    lr=0.001,
                    momentum=0.9,
                    weight_decay=1e-5)
    train_loader = get_loader(is_train=True, batch_size=64)
    val_loader = get_loader(is_train=False, batch_size=1000)
    criterion = CircleLoss(m=0.25, gamma=80)

    for epoch in range(20):
        for img, label in tqdm(train_loader):
            model.zero_grad()
            pred = model(img)
            loss = criterion(*convert_label_to_similarity(pred, label))
            loss.backward()
            optimizer.step()

    # Evaluation must not run in training mode: layers such as dropout or
    # batch norm (if the model has any) behave differently, and recording an
    # autograd graph for 1000-sample batches only wastes memory.
    model.eval()
    all_features = []
    all_labels = []
    with torch.no_grad():
        for img, label in val_loader:
            pred = model(img)
            all_features.append(pred.cpu().numpy())
            all_labels.append(label.cpu().numpy())
    all_features = np.concatenate(all_features, 0)
    all_labels = np.concatenate(all_labels, 0)
    # NOTE(review): the literal 10 is presumably the number of classes -
    # confirm against plot_features.
    plot_features(all_features, all_labels, 10)
예제 #2
0
def main(args, resume: bool = True) -> None:
    """Train a Circle Loss model on the GPU, or restore it from a checkpoint.

    When resuming is requested and a checkpoint exists, the weights are
    loaded and no training happens. Otherwise the model is trained, scalar
    summaries are written every 10 steps, validation metrics are written
    after every epoch, and the final weights are saved to
    ``<args.tl_outdir>/resume.state``.

    Args:
        args: Parsed options. This function reads ``tl_outdir``, ``datadir``,
            ``tl_resume``, ``tl_time_str`` and ``tl_debug``.
        resume: Set to False to ignore an existing checkpoint even when
            ``args.tl_resume`` is set. The default (True) leaves the decision
            to ``args.tl_resume``.
    """
    logger = logging.getLogger('tl')
    saved_model = os.path.join(args.tl_outdir, "resume.state")

    model = Model().cuda()
    optimizer = SGD(model.parameters(),
                    lr=0.001,
                    momentum=0.9,
                    weight_decay=1e-5)
    train_loader = get_loader(datadir=args.datadir,
                              is_train=True,
                              batch_size=64)
    val_loader = get_loader(datadir=args.datadir, is_train=False, batch_size=2)
    criterion = CircleLoss(m=0.25, gamma=80)

    # The checkpoint is written to ``saved_model`` (see the end of this
    # function), so that is where it has to be looked up. The bare
    # "resume.state" in the working directory is kept as a fallback for
    # checkpoints placed there by hand.
    checkpoint = None
    for candidate in (saved_model, "resume.state"):
        if os.path.exists(candidate):
            checkpoint = candidate
            break

    if resume and args.tl_resume and checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
    else:
        counter = 0  # global step across epochs, used as the summary x-axis
        for epoch in range(100000):
            logger.info(f'Epoch {epoch}')
            pbar = tqdm(train_loader, desc=f'{args.tl_time_str}')
            for img, label in pbar:
                img = img.cuda()
                label = label.cuda()
                model.zero_grad()
                pred = model(img)
                sp, sn = convert_label_to_similarity(pred, label)
                loss = criterion(sp, sn)
                loss.backward()
                optimizer.step()
                if counter % 10 == 0:
                    summary_dicts = collections.defaultdict(dict)
                    summary_dicts['sp_sn']['sp_mean'] = sp.mean().item()
                    summary_dicts['sp_sn']['sn_mean'] = sn.mean().item()
                    summary_dicts['loss']['loss'] = loss.item()
                    summary_defaultdict2txtfig(default_dict=summary_dicts,
                                               prefix='train',
                                               step=counter,
                                               textlogger=global_textlogger,
                                               save_fig_sec=90)
                counter += 1
            recal, pre, (tp, fp, fn, tn) = validate(val_loader=val_loader,
                                                    model=model)
            summary_dicts = collections.defaultdict(dict)
            summary_dicts['recal']['recal'] = recal
            summary_dicts['pre']['pre'] = pre
            summary_dicts['tp_fp_fn_tn']['tp'] = tp
            summary_dicts['tp_fp_fn_tn']['fp'] = fp
            summary_dicts['tp_fp_fn_tn']['fn'] = fn
            summary_dicts['tp_fp_fn_tn']['tn'] = tn
            summary_defaultdict2txtfig(default_dict=summary_dicts,
                                       prefix='val',
                                       step=epoch,
                                       textlogger=global_textlogger,
                                       save_fig_sec=90)
            # Debug runs stop after a single epoch.
            if args.tl_debug:
                break
        torch.save(model.state_dict(), saved_model)
예제 #3
0
def train(model, criterion, optimizer, epoch, loader, device):
    """Run one training epoch of ``model`` over ``loader``.

    Each batch is moved to ``device``, passed through the model, converted
    to similarity inputs with ``convert_label_to_similarity`` and scored by
    ``criterion``; the resulting loss is back-propagated and ``optimizer``
    takes one step. ``epoch`` is only used for the progress message.
    """
    print("Training... Epoch = %d" % epoch)
    for batch_img, batch_label in tqdm(loader):
        batch_img, batch_label = batch_img.to(device), batch_label.to(device)

        # Clear gradients left over from the previous batch.
        model.zero_grad()

        embeddings = model(batch_img)
        similarity_args = convert_label_to_similarity(embeddings, batch_label)
        batch_loss = criterion(*similarity_args)

        batch_loss.backward()
        optimizer.step()
예제 #4
0
def main() -> None:
    """Train a backbone with Circle Loss, then fit a classifier on top of it.

    Phase 1 trains ``Model`` for 20 epochs with Circle Loss. Phase 2 trains
    ``Classifier`` for 20 epochs with cross-entropy on the backbone's
    features and prints the validation accuracy after every epoch. Only the
    classifier's parameters are updated in phase 2.
    """
    # Local import: ``torch.no_grad`` is needed for the frozen-backbone passes.
    import torch

    num_epochs = 20

    model = Model()
    classifier = Classifier()
    optimizer = SGD(model.parameters(),
                    lr=0.001,
                    momentum=0.9,
                    weight_decay=1e-5)
    optimizer_cls = SGD(classifier.parameters(),
                        lr=0.001,
                        momentum=0.9,
                        weight_decay=1e-5)
    train_loader = get_loader(is_train=True, batch_size=64)
    val_loader = get_loader(is_train=False, batch_size=1000)
    criterion = CircleLoss(m=0.25, gamma=80)
    criterion_xe = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        for img, label in train_loader:
            model.zero_grad()
            features = model(img)
            loss = criterion(*convert_label_to_similarity(features, label))
            loss.backward()
            optimizer.step()
        print('[{}/{}] Training with Circle Loss.'.format(epoch + 1, num_epochs))

    # From here on the backbone is a fixed feature extractor: no optimizer
    # touches its parameters, so it is put in inference mode and its forward
    # passes skip autograd instead of back-propagating gradients nobody uses.
    model.eval()

    for epoch in range(num_epochs):
        classifier.train()
        for img, label in train_loader:
            classifier.zero_grad()
            with torch.no_grad():
                features = model(img)
            output = classifier(features)
            loss = criterion_xe(output, label)
            loss.backward()
            optimizer_cls.step()
        print('[{}/{}] Training classifier.'.format(epoch + 1, num_epochs))

        classifier.eval()
        correct = 0
        with torch.no_grad():
            for img, label in val_loader:
                features = model(img)
                output = classifier(features)
                pred = output.max(1)[1]  # index of the highest score per row
                # .item() keeps ``correct`` a plain int so the report below
                # prints a number rather than ``tensor(...)``.
                correct += pred.eq(label).sum().item()
        total = len(val_loader.dataset)
        print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
            correct, total, 100. * correct / total))
def main() -> None:
    """Train with Circle Loss, then plot one validation pair with a verdict.

    After 20 training epochs the first validation batch (two images) is
    passed through the model. The sum of the element-wise product of the two
    outputs is compared against a fixed threshold and the result is passed
    to ``plot`` together with the two images.
    """
    model = Model()
    optimizer = SGD(model.parameters(),
                    lr=0.001,
                    momentum=0.9,
                    weight_decay=1e-5)
    train_loader = get_loader(is_train=True, batch_size=64)
    val_loader = get_loader(is_train=False, batch_size=2)
    criterion = CircleLoss(m=0.25, gamma=80)

    for epoch in range(20):
        for img, label in tqdm(train_loader):
            model.zero_grad()
            pred = model(img)
            loss = criterion(*convert_label_to_similarity(pred, label))
            loss.backward()
            optimizer.step()

    thresh = 0.75
    # Inference on a 2-sample batch must not run in training mode: any
    # batch-statistics layer in the model would normalise over just those two
    # samples. Gradients are not needed either.
    model.eval()
    with torch.no_grad():
        for img, _ in val_loader:
            pred = model(img)
            # NOTE(review): this is a cosine similarity only if the model
            # emits L2-normalised embeddings - confirm against Model.
            pred_label = torch.sum(pred[0] * pred[1]) > thresh
            plot(img[0, 0].detach().numpy(), img[1, 0].detach().numpy(),
                 pred_label)
            break  # only the first validation pair is visualised
예제 #6
0
def train_model(model, model_test, criterion, optimizer, scheduler, num_epochs=25):
    """Train a multi-view model and return it.

    Batches are drawn in lockstep from the 'satellite', 'street', 'drone'
    and 'google' entries of the module-level ``dataloaders``. The base loss
    is ``criterion`` applied to each view's output; depending on the ``opt``
    flags (arcface, cosface, circle, triplet, lifted, contrast, sphere)
    extra metric-learning terms computed on L2-normalised features are added
    when ``opt.views == 3``.

    Besides its arguments the function reads these module-level names:
    ``opt``, ``dataset_sizes``, ``dataloaders``, ``start_epoch``,
    ``use_gpu``, ``fp16``, ``amp``, ``version``, ``y_loss`` and ``y_err``.

    Args:
        model: Network being trained; called with 2, 3 or 4 input batches
            depending on ``opt.views`` and ``opt.extra_Google``.
        model_test: Passed to ``update_average`` together with ``model``
            after each optimizer step when ``opt.moving_avg < 1.0``.
        criterion: Base loss applied to every view's output and labels.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        num_epochs: Index of the last epoch plus one; training runs from
            ``start_epoch`` up to ``num_epochs - 1``.

    Returns:
        The ``model`` argument, after training.
    """
    since = time.time()

    # best_model_wts = model.state_dict()
    # best_acc = 0.0
    warm_up = 0.1  # We start from the 0.1*lrRate
    # Number of iterations over which the warm-up factor grows from 0.1 to 1.0.
    warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch  # first 5 epoch

    # Build only the auxiliary criteria that were requested through ``opt``.
    if opt.arcface:
        criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.cosface:
        criterion_cosface = losses.CosFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.circle:
        criterion_circle = CircleLoss(m=0.25, gamma=32)  # gamma = 64 may lead to a better result.
    if opt.triplet:
        miner = miners.MultiSimilarityMiner()
        criterion_triplet = losses.TripletMarginLoss(margin=0.3)
    if opt.lifted:
        criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0)
    if opt.contrast:
        criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
    if opt.sphere:
        criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)

    # The loop counts the remaining epochs; the real epoch index is restored
    # on the first line of the body.
    for epoch in range(num_epochs - start_epoch):
        epoch = epoch + start_epoch
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Only a training phase is run here, so the ``phase == 'val'``
        # branches below are never taken.
        for phase in ['train']:
            if phase == 'train':
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_corrects2 = 0.0
            running_corrects3 = 0.0
            # Iterate over data.
            # zip() stops as soon as the shortest of the four loaders is exhausted.
            for data, data2, data3, data4 in zip(dataloaders['satellite'], dataloaders['street'], dataloaders['drone'],
                                                 dataloaders['google']):
                # get the inputs
                inputs, labels = data
                inputs2, labels2 = data2
                inputs3, labels3 = data3
                inputs4, labels4 = data4
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue
                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    inputs2 = Variable(inputs2.cuda().detach())
                    inputs3 = Variable(inputs3.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                    labels2 = Variable(labels2.cuda().detach())
                    labels3 = Variable(labels3.cuda().detach())
                    # The fourth view is only moved to the GPU when it is used.
                    if opt.extra_Google:
                        inputs4 = Variable(inputs4.cuda().detach())
                        labels4 = Variable(labels4.cuda().detach())
                else:
                    # Only the first view is wrapped on the CPU path.
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs, outputs2 = model(inputs, inputs2)
                else:
                    if opt.views == 2:
                        outputs, outputs2 = model(inputs, inputs2)
                    elif opt.views == 3:
                        if opt.extra_Google:
                            outputs, outputs2, outputs3, outputs4 = model(inputs, inputs2, inputs3, inputs4)
                        else:
                            outputs, outputs2, outputs3 = model(inputs, inputs2, inputs3)

                # When any metric-learning loss is enabled, each output is
                # unpacked below as a (logits, feature) pair.
                return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere

                if opt.views == 2:
                    _, preds = torch.max(outputs.data, 1)
                    _, preds2 = torch.max(outputs2.data, 1)
                    loss = criterion(outputs, labels) + criterion(outputs2, labels2)
                elif opt.views == 3:
                    if return_feature:
                        logits, ff = outputs
                        logits2, ff2 = outputs2
                        logits3, ff3 = outputs3
                        # L2-normalise every feature row before the metric losses.
                        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                        fnorm2 = torch.norm(ff2, p=2, dim=1, keepdim=True)
                        fnorm3 = torch.norm(ff3, p=2, dim=1, keepdim=True)
                        ff = ff.div(fnorm.expand_as(ff))  # 8*512,tensor
                        ff2 = ff2.div(fnorm2.expand_as(ff2))
                        ff3 = ff3.div(fnorm3.expand_as(ff3))
                        loss = criterion(logits, labels) + criterion(logits2, labels2) + criterion(logits3, labels3)
                        _, preds = torch.max(logits.data, 1)
                        _, preds2 = torch.max(logits2.data, 1)
                        _, preds3 = torch.max(logits3.data, 1)
                        # Multiple perspectives are combined to calculate losses, please join ''--loss_merge'' in run.sh
                        if opt.loss_merge:
                            ff_all = torch.cat((ff, ff2, ff3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                        if opt.extra_Google:
                            logits4, ff4 = outputs4
                            fnorm4 = torch.norm(ff4, p=2, dim=1, keepdim=True)
                            ff4 = ff4.div(fnorm4.expand_as(ff4))
                            # Recomputes the base loss with the fourth view included.
                            loss = criterion(logits, labels) + criterion(logits2, labels2) + criterion(logits3, labels3) +criterion(logits4, labels4)
                            if opt.loss_merge:
                                ff_all = torch.cat((ff_all, ff4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                        # Each enabled metric loss is added either once on the
                        # merged features or separately per view.
                        if opt.arcface:
                            if opt.loss_merge:
                                loss += criterion_arcface(ff_all, labels_all)
                            else:
                                loss += criterion_arcface(ff, labels) + criterion_arcface(ff2, labels2) + criterion_arcface(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_arcface(ff4, labels4)  # /now_batch_size
                        if opt.cosface:
                            if opt.loss_merge:
                                loss += criterion_cosface(ff_all, labels_all)
                            else:
                                loss += criterion_cosface(ff, labels) + criterion_cosface(ff2, labels2) + criterion_cosface(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_cosface(ff4, labels4)  # /now_batch_size
                        if opt.circle:
                            # Circle Loss terms are divided by the batch size.
                            if opt.loss_merge:
                                loss += criterion_circle(*convert_label_to_similarity(ff_all, labels_all)) / now_batch_size
                            else:
                                loss += criterion_circle(*convert_label_to_similarity(ff, labels)) / now_batch_size + criterion_circle(*convert_label_to_similarity(ff2, labels2)) / now_batch_size + criterion_circle(*convert_label_to_similarity(ff3, labels3)) / now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_circle(*convert_label_to_similarity(ff4, labels4)) / now_batch_size
                        if opt.triplet:
                            # The miner's output is passed to the triplet loss as its third argument.
                            if opt.loss_merge:
                                hard_pairs_all = miner(ff_all, labels_all)
                                loss += criterion_triplet(ff_all, labels_all, hard_pairs_all)
                            else:
                                hard_pairs = miner(ff, labels)
                                hard_pairs2 = miner(ff2, labels2)
                                hard_pairs3 = miner(ff3, labels3)
                                loss += criterion_triplet(ff, labels, hard_pairs) + criterion_triplet(ff2, labels2, hard_pairs2) + criterion_triplet(ff3, labels3, hard_pairs3)# /now_batch_size
                                if opt.extra_Google:
                                    hard_pairs4 = miner(ff4, labels4)
                                    loss += criterion_triplet(ff4, labels4, hard_pairs4)
                        if opt.lifted:
                            if opt.loss_merge:
                                loss += criterion_lifted(ff_all, labels_all)
                            else:
                                loss += criterion_lifted(ff, labels) + criterion_lifted(ff2, labels2) + criterion_lifted(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_lifted(ff4, labels4)
                        if opt.contrast:
                            if opt.loss_merge:
                                loss += criterion_contrast(ff_all, labels_all)
                            else:
                                loss += criterion_contrast(ff, labels) + criterion_contrast(ff2,labels2) + criterion_contrast(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_contrast(ff4, labels4)
                        if opt.sphere:
                            if opt.loss_merge:
                                loss += criterion_sphere(ff_all, labels_all) / now_batch_size
                            else:
                                loss += criterion_sphere(ff, labels) / now_batch_size + criterion_sphere(ff2, labels2) / now_batch_size + criterion_sphere(ff3, labels3) / now_batch_size
                                if opt.extra_Google:
                                    # NOTE(review): unlike the three terms above, this one is
                                    # not divided by now_batch_size - confirm whether intended.
                                    loss += criterion_sphere(ff4, labels4)

                    else:
                        # Plain classification: the outputs are the logits themselves.
                        _, preds = torch.max(outputs.data, 1)
                        _, preds2 = torch.max(outputs2.data, 1)
                        _, preds3 = torch.max(outputs3.data, 1)
                        if opt.loss_merge:
                            outputs_all = torch.cat((outputs, outputs2, outputs3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                            if opt.extra_Google:
                                outputs_all = torch.cat((outputs_all, outputs4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                            # NOTE(review): the factor 4 is applied whether or not the
                            # fourth (Google) view is present - confirm whether intended.
                            loss = 4*criterion(outputs_all, labels_all)
                        else:
                            loss = criterion(outputs, labels) + criterion(outputs2, labels2) + criterion(outputs3, labels3)
                            if opt.extra_Google:
                                loss += criterion(outputs4, labels4)

                # backward + optimize only if in training phase
                # During the warm-up epochs the loss is scaled by a factor that
                # grows by 0.9 / warm_iteration per batch, capped at 1.0.
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up

                if phase == 'train':
                    if fp16:  # we use optimizer to backward loss
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    ##########
                    if opt.moving_avg < 1.0:
                        update_average(model_test, model, opt.moving_avg)

                # statistics
                # ``version`` is indexed by character: position 0 is compared
                # against 0 and position 2 against 3.
                if int(version[0]) > 0 or int(version[2]) > 3:  # for the new version like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:  # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                running_corrects2 += float(torch.sum(preds2 == labels2.data))
                if opt.views == 3:
                    running_corrects3 += float(torch.sum(preds3 == labels3.data))

            # All per-epoch averages are normalised by the satellite dataset size.
            epoch_loss = running_loss / dataset_sizes['satellite']
            epoch_acc = running_corrects / dataset_sizes['satellite']
            epoch_acc2 = running_corrects2 / dataset_sizes['satellite']

            if opt.views == 2:
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc,
                                                                                         epoch_acc2))
            elif opt.views == 3:
                epoch_acc3 = running_corrects3 / dataset_sizes['satellite']
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f} Drone_Acc: {:.4f}'.format(phase,
                                                                                                           epoch_loss,
                                                                                                           epoch_acc,
                                                                                                           epoch_acc2,
                                                                                                           epoch_acc3))

            # Only the satellite-view accuracy feeds the recorded error curve.
            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            # deep copy the model
            if phase == 'train':
                scheduler.step()
            # NOTE(review): last_model_wts is assigned but never read in this function.
            last_model_wts = model.state_dict()
            # Save a checkpoint on every 20th epoch (epochs 19, 39, ...).
            if epoch % 20 == 19:
                save_network(model, opt.name, epoch)
            # draw_curve(epoch)

        # Elapsed time so far; this message is printed after every epoch.
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # print('Best val Acc: {:4f}'.format(best_acc))
    # save_network(model_test, opt.name+'adapt', epoch)

    return model
예제 #7
0
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train and validate ``model`` for ``num_epochs`` epochs and return it.

    Every epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders``. The loss is ``criterion`` on the logits,
    plus a Circle Loss term on L2-normalised features when ``opt.circle`` is
    set, or the sum of ``criterion`` over six part outputs when ``opt.PCB``
    is set. Per-phase loss and error are appended to ``y_loss`` / ``y_err``;
    after each validation phase the curve is redrawn and a checkpoint is
    saved every tenth epoch. A final checkpoint named 'last' is always saved.

    Besides its arguments the function reads these module-level names:
    ``opt``, ``dataset_sizes``, ``dataloaders``, ``use_gpu``, ``fp16``,
    ``amp``, ``version``, ``y_loss`` and ``y_err``.

    Args:
        model: Network being trained.
        criterion: Base classification loss applied to logits and labels.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler, stepped once per epoch after the
            training phase.
        num_epochs: Number of epochs to run.

    Returns:
        The ``model`` argument, after training.
    """
    since = time.time()

    warm_up = 0.1  # We start from the 0.1*lrRate
    # Number of iterations over which the warm-up factor grows from 0.1 to 1.0.
    warm_iteration = round(dataset_sizes['train'] /
                           opt.batchsize) * opt.warm_epoch  # first 5 epoch
    if opt.circle:
        criterion_circle = CircleLoss(m=0.25, gamma=32)

    # Softmax has no parameters or state, so one instance serves every batch.
    sm = nn.Softmax(dim=1)

    # Bound up front so the final load_state_dict() cannot hit an unbound
    # name when num_epochs is 0.
    last_model_wts = model.state_dict()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            # Training mode for 'train', evaluation mode for 'val'.
            model.train(phase == 'train')

            running_loss = 0.0
            running_corrects = 0.0
            # Iterate over data.
            for data in dataloaders[phase]:
                # get the inputs
                inputs, labels = data
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue
                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs = model(inputs)
                else:
                    outputs = model(inputs)

                if opt.circle:
                    # The model output is unpacked as a (logits, feature)
                    # pair; the feature rows are L2-normalised before the
                    # Circle Loss, whose value is divided by the batch size.
                    logits, ff = outputs
                    fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                    ff = ff.div(fnorm.expand_as(ff))
                    loss = criterion(logits, labels) + criterion_circle(
                        *convert_label_to_similarity(ff,
                                                     labels)) / now_batch_size
                    _, preds = torch.max(logits.data, 1)

                elif not opt.PCB:  #  norm
                    _, preds = torch.max(outputs.data, 1)
                    loss = criterion(outputs, labels)
                else:  # PCB
                    # Six part outputs: predictions come from the summed
                    # softmax scores, the loss is the sum of the part losses.
                    part = {}
                    num_part = 6
                    for i in range(num_part):
                        part[i] = outputs[i]

                    score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(
                        part[3]) + sm(part[4]) + sm(part[5])
                    _, preds = torch.max(score.data, 1)

                    loss = criterion(part[0], labels)
                    for i in range(num_part - 1):
                        loss += criterion(part[i + 1], labels)

                # backward + optimize only if in training phase
                # During the warm-up epochs the loss is scaled by a factor
                # that grows by 0.9 / warm_iteration per batch, capped at 1.0.
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss = loss * warm_up

                if phase == 'train':
                    if fp16:  # we use optimizer to backward loss
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()

                # statistics
                if int(version[0]) > 0 or int(
                        version[2]
                ) > 3:  # for the new version like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:  # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)

            if phase == 'train':
                # Advance the schedule only after this epoch's optimizer
                # steps. Stepping the scheduler first skips the initial
                # learning-rate value on PyTorch >= 1.1.
                scheduler.step()
            else:
                # state_dict() holds references to the live parameters, not
                # a copy, so this always reflects the latest weights.
                last_model_wts = model.state_dict()
                if epoch % 10 == 9:
                    save_network(model, epoch)
                draw_curve(epoch)

        # Elapsed time so far; this message is printed after every epoch.
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # Reload the most recent weights and save them as the final checkpoint.
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')
    return model