Ejemplo n.º 1
0
def train(model):
    """Train *model* for ``args.epochs`` epochs with Adam (lr=1e-3).

    Relies on module-level ``args`` (``batch_size``, ``epochs``, ``exp_name``),
    ``data_loader``, ``utils``, ``evaluate``, ``optim`` and ``torch``.

    Per epoch:
      * every 100th epoch (including epoch 0, i.e. before any training step)
        the whole model object is written with ``torch.save`` to
        ``./result/models/model_<epoch>_<exp_name>.pth``;
      * every batch is wrapped by ``utils.to_var`` and handed to
        ``model.run_on_batch(data, optimizer, epoch)``, whose result must be a
        mapping with a ``'loss'`` entry supporting ``.item()``;
      * a single progress line (percent done, running average loss) is
        rewritten in place;
      * every 10th epoch ``evaluate(model, val_iter)`` is called.
    """
    import sys  # local import: only needed for the in-place progress line

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    data_iter = data_loader.get_train_loader(batch_size=args.batch_size)
    val_iter = data_loader.get_val_loader(batch_size=args.batch_size)
    num_batches = len(data_iter)  # loop-invariant; used for the progress percentage

    for epoch in range(args.epochs):
        model.train()

        if epoch % 100 == 0:
            print('Save checkpoint')
            torch.save(model, './result/models/model_{}_{}.pth'.format(
                epoch, args.exp_name))

        run_loss = 0.0

        for idx, data in enumerate(data_iter):
            data = utils.to_var(data)
            ret = model.run_on_batch(data, optimizer, epoch)

            run_loss += ret['loss'].item()

            # The original `print(...),` relied on the Python 2 trailing-comma
            # idiom to suppress the newline; under Python 3 it printed one line
            # per batch. Writing to stdout directly keeps the `\r` overwrite
            # working on both interpreters.
            sys.stdout.write('\r Progress epoch {}, {:.2f}%, average loss {}'.format(
                epoch, (idx + 1) * 100.0 / num_batches, run_loss / (idx + 1.0)))
            sys.stdout.flush()

        # Terminate the progress line so subsequent output starts on a new line.
        sys.stdout.write('\n')

        if epoch % 10 == 0:
            evaluate(model, val_iter)
Ejemplo n.º 2
0
def train(model, fine_tune, pseudo, num_epochs=100, data_sets=None):
    """Build the optimizer and data loaders for *model* and run ``train_model``.

    Args:
        model: network to train. When ``fine_tune`` is true it must expose a
            ``name`` attribute; its prefix selects the head that is optimized.
        fine_tune: if true, only the dense head (``model.fc`` for
            resnet/inception, ``model.classifier`` for densenet/vgg) is given
            to the optimizer; otherwise all of ``model.parameters()`` are.
        pseudo: if true, the training loader comes from
            ``data_loader.get_pseudo_train_loader`` and ``max_num`` is raised
            from 2 to 4.
        num_epochs: forwarded to ``train_model``.
        data_sets: ``(train_data, valid_data)`` pair; with ``pseudo`` the first
            element is the pseudo-labelled data. Required despite the default.

    Returns:
        The value returned by ``train_model``.

    Raises:
        ValueError: if ``data_sets`` is not supplied, or if ``fine_tune`` is
            set and ``model.name`` matches no known architecture prefix.
    """
    if data_sets is None:
        # Previously this failed with an opaque "cannot unpack NoneType" error.
        raise ValueError('data_sets must be a (train_data, valid_data) pair')
    train_data, valid_data = data_sets

    init_lr = 0.0001
    criterion = nn.BCELoss()

    if fine_tune:
        arch = model.name

        if arch.startswith(('resnet', 'inception')):
            dense_layers = model.fc
        elif arch.startswith(('densenet', 'vgg')):
            dense_layers = model.classifier
        else:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers still catch it.
            raise ValueError('unknown model: {}'.format(arch))

        optimizer_ft = optim.SGD(dense_layers.parameters(),
                                 lr=init_lr,
                                 momentum=0.9)
        # NOTE(review): the optimizer is created with lr=1e-4 but train_model /
        # lr_scheduler receive init_lr=1e-3; presumably the scheduler resets the
        # learning rate from init_lr — confirm this is intended.
        init_lr = 0.001
    else:
        optimizer_ft = optim.SGD(model.parameters(), lr=init_lr, momentum=0.9)

    max_num = 2
    if pseudo:
        train_loader = data_loader.get_pseudo_train_loader(model, train_data)
        max_num += 2
    else:
        train_loader = data_loader.get_train_loader(model, train_data)

    data_loaders = {
        'train': train_loader,
        'valid': data_loader.get_val_loader(model, valid_data),
    }

    model = train_model(model,
                        criterion,
                        optimizer_ft,
                        lr_scheduler,
                        max_num=max_num,
                        init_lr=init_lr,
                        num_epochs=num_epochs,
                        data_loaders=data_loaders,
                        fine_tune=fine_tune,
                        pseudo=pseudo)
    return model
Ejemplo n.º 3
0
def _pair_cosine_scores(net, loader):
    """Score every image pair yielded by *loader* with *net*.

    Each item from *loader* is ``(pairs, labels)``. ``pairs[0]`` and
    ``pairs[1]`` are passed through ``net`` separately, and the second element
    of each output is used as the embedding.

    Returns:
        ``np.ndarray`` with one ``[cosine_similarity, label]`` row per pair.
    """
    scores = []
    for pairs, labels in loader:
        img_a = Variable(pairs[0]).type(Tensor)
        img_b = Variable(pairs[1]).type(Tensor)

        _, embs_a = net(img_a)
        _, embs_b = net(img_b)

        # Use the raw tensors (detached from autograd) for scoring.
        embs_a = embs_a.data
        embs_b = embs_b.data

        for emb_a, emb_b, label in zip(embs_a, embs_b, labels):
            # 1e-5 guards against division by zero for all-zero embeddings.
            cos_sim = emb_a.dot(emb_b) / (emb_a.norm() * emb_b.norm() + 1e-5)
            # float() gives a plain number whether dot() returned a Python
            # float or a 0-dim tensor, so np.array below is numeric.
            scores.append([float(cos_sim), int(label)])
    return np.array(scores)


def validate(net, base_dir='val'):
    """Measure pair-verification accuracy of *net* on the first pair type.

    For the first entry of the module-level ``ptypes``, pairs are loaded from
    ``os.path.join(base_dir, csv_base_path.format(ptype))`` via
    ``get_val_loader`` and scored by cosine similarity of their embeddings.
    5-fold cross-validation (legacy ``KFold(n, n_folds)`` signature, as in the
    original) picks a threshold on each training split with
    ``find_best_threshold`` and scores it on the held-out split with
    ``eval_acc``. Mean TPR/FPR/accuracy, accuracy std and mean threshold are
    printed.

    Args:
        net: model whose forward returns a 2-tuple; the second element is
            taken as the embedding. It is switched to eval mode.
        base_dir: directory holding the validation images and CSV pair lists.

    Returns:
        Mean fold accuracy for the first pair type, or ``None`` if ``ptypes``
        is empty.
    """
    print('Begin validation')

    net.eval()
    mean_accuracy = None
    for ptype in ptypes:
        csv_path = os.path.join(base_dir, csv_base_path.format(ptype))
        loader = get_val_loader(base_dir, csv_path)

        dists = _pair_cosine_scores(net, loader)

        tprs, fprs, accuracy, thd = [], [], [], []

        # Folds must index pairs (rows of dists). len(loader) is the number of
        # batches, which silently left most pairs out of every fold whenever a
        # batch held more than one pair.
        folds = KFold(n=len(dists), n_folds=5, shuffle=False)
        thresh = np.arange(-1.0, 1.0, 0.005)
        for train_idx, test_idx in folds:
            best_thresh = find_best_threshold(thresh, dists[train_idx])
            tpr, fpr, acc = eval_acc(best_thresh, dists[test_idx])
            tprs.append(tpr)
            fprs.append(fpr)
            accuracy.append(acc)
            thd.append(best_thresh)

        mean_accuracy = np.mean(accuracy)
        print(
            'PTYPE={} TPR={:.4f} FPR={:.4f} ACC={:.4f} std={:.4f} thd={:.4f}'.
            format(ptype, np.mean(tprs), np.mean(fprs), mean_accuracy,
                   np.std(accuracy), np.mean(thd)))

        # Only the first pair type is validated. The original expressed this
        # as `if i > 0: return` at the top of the loop, which returned None
        # when ptypes held a single entry.
        break
    return mean_accuracy
Ejemplo n.º 4
0
    csv_base_path = 'pairs_withlabel-CSV/{}_val.csv'
    for ptype in ptypes:
=======

def validate(net, base_dir='val'):
    print('Begin validation')

    net.eval()
    accuracy = []
    for i, ptype in enumerate(ptypes):
        if i > 0:
            return np.mean(accuracy)
>>>>>>> d5d57c0ffdab6a2eac16d8809ae66bd6ab8f5f19
        csv_path = os.path.join(base_dir, csv_base_path.format(ptype))

        loader = get_val_loader(base_dir, csv_path)

        dists = []
        for pairs, labels in iter(loader):
<<<<<<< HEAD
            img_a = torch.FloatTensor(pairs[0]).cuda()
            img_b = torch.FloatTensor(pairs[1]).cuda()
=======
            img_a = Variable(pairs[0]).type(Tensor)
            img_b = Variable(pairs[1]).type(Tensor)
            # img_b = Tensor(pairs[1])
>>>>>>> d5d57c0ffdab6a2eac16d8809ae66bd6ab8f5f19

            _, embs_a = net(img_a)
            _, embs_b = net(img_b)
Ejemplo n.º 5
0
    def __init__(self, args, model, optimizer, lr_policy):
        """Set up the training logger, device placement and data loaders.

        Args:
            args: parsed options. Read here: ``class_num``, ``save_dir``,
                ``save_name``, ``use_http_server``, ``use_msg_server``,
                ``viz_fetch_stride``, ``http_server_port``,
                ``msg_server_port``, ``nogpu``, ``gpu_device_num`` and
                ``show_parameters``. The whole namespace is also passed to
                both data loaders.
            model: network to train. Moved to ``cuda:<gpu_device_num>`` when
                CUDA is available and ``args.nogpu`` is false.
            optimizer: stored unchanged as ``self.optimizer``.
            lr_policy: its ``iteration_wise`` flag decides whether log rows
                are keyed by "iter" or "epoch".
        """
        self.args = args
        self.lr_policy = lr_policy
        self.iter_wise = self.lr_policy.iteration_wise
        step_column = "iter" if self.iter_wise else "epoch"

        # Logging headers: training logs the batch loss; validation logs pixel
        # accuracy plus per-class precision and IoU.
        class_ids = range(self.args.class_num)
        val_head = [step_column, "mean_pixel_accuracy"]
        val_head += ["mean_precision_class_{}".format(i) for i in class_ids]
        val_head += ["mean_IoU_class_{}".format(i) for i in class_ids]
        self.tlog = self.get_train_logger(
            {
                "train": [step_column, "batch_mean_total_loss"],
                "val": val_head,
            },
            save_dir=self.args.save_dir,
            save_name=self.args.save_name,
            arguments=self.get_argparse_arguments(self.args),
            use_http_server=self.args.use_http_server,
            use_msg_server=self.args.use_msg_server,
            notificate=False,
            visualize_fetch_stride=self.args.viz_fetch_stride,
            http_port=self.args.http_server_port,
            msg_port=self.args.msg_server_port)

        # paths
        self.save_dir = self.tlog.log_save_path
        self.model_param_dir = self.tlog.mkdir("model_param")

        # Decide once whether to use the GPU so the chosen device and the model
        # placement below cannot disagree (previously the check was repeated,
        # once via args.nogpu and once via self.args.nogpu).
        use_gpu = torch.cuda.is_available() and not self.args.nogpu
        if use_gpu:
            self.map_device = torch.device('cuda:{}'.format(
                self.args.gpu_device_num))
        else:
            self.map_device = torch.device('cpu')

        self.model = model
        if use_gpu:
            self.model = self.model.to(self.map_device)

        self.optimizer = optimizer

        # Both loaders receive the same per-channel pair — presumably
        # (mean, std); the commented-out alternative in the original was the
        # ImageNet statistics [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)].
        self.train_loader = data_loader.get_train_loader(
            self.args, [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])
        self.val_loader = data_loader.get_val_loader(
            self.args, [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])

        self.cmap = self._gen_cmap()

        if self.args.show_parameters:
            for idx, m in enumerate(model.modules()):
                print(idx, '->', m)
            print(args)

        print("\nsaving at {}\n".format(self.save_dir))
Ejemplo n.º 6
0
                            help='Root directory of data (assumed to containing pairs list labels)')
    parser.add_argument('--data_dir', '-d', type=str, default=sys_home() + '/datasets/FIW/RFIW/val/',
                        help='Root directory of data (assumed to contain valdata)')

    args = parser.parse_args()

    net = net_sphere.sphere20a(classnum=300)

    if cuda:
        net.cuda()

    epoch, bess_acc = TorchTools.load_checkpoint(net, f_weights=args.modelpath)

    ncols = int(np.ceil(len(do_types) / 2))
    nrows = 2
    f, axes = plt.subplots(nrows, ncols, sharex='all', sharey='all')

    for i, id in enumerate(do_types):
        if i < ncols:
            ax = axes[0, i]
        else:
            ax = axes[1, i - ncols]
        csv_file = os.path.join(args.label_dir, types[id] + '_val.csv')
        loader = get_val_loader(args.data_dir, csv_file)
        # f.subplot()
        auc_score = validate(net, loader, ax)

        print('{} pairs: {} (auc)'.format(types[id], auc_score))

    plt.savefig('roc.png')
    # optimizer = optim.Adam(net.parameters(), lr=args.lr)
    best_acc = 0
    if not args.train:
        print('Begin train')
        for epoch in range(args.n_epochs):
            train_set, train_loader = get_train_loader(
                image_size=args.img_size,
                batch_size=args.train_batch_size,
                train_steps=args.train_steps,
                val_steps=args.val_steps,
                one_to_zero_train=args.one_to_zero_train,
                one_to_zero_val=args.one_to_zero_val)
            val_loader = get_val_loader(
                image_size=args.img_size,
                batch_size=args.val_batch_size,
                train_steps=args.train_steps,
                val_steps=args.val_steps,
                one_to_zero_train=args.one_to_zero_train,
                one_to_zero_val=args.one_to_zero_val)
            print("epoch:", epoch)

            #    if epoch in args.change_lr_for_epochs:
            #        args.lr *= 0.1
            #        optimizer = optim.SGD(param_groups, lr=args.lr, momentum=0.9, weight_decay=5e-4)

            train(net, net2, optimizer, epoch, train_loader)
            acc, val_loss = validate(net, net2, val_loader)
            print('accuracy = ', acc)
            scheduler.step(val_loss)

            if best_acc < acc: