コード例 #1
0
def main(args):
    """Train a Net on the source domain, checkpointing the best model by
    validation accuracy and annealing the LR on validation-loss plateaus.
    """
    train_loader, val_loader = create_dataloaders(args.batch_size)

    model = Net().to(device)
    optim = torch.optim.Adam(model.parameters())
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, patience=1, verbose=True)
    criterion = torch.nn.CrossEntropyLoss()

    best_accuracy = 0
    for epoch in range(1, args.epochs + 1):
        # Training pass: do_epoch steps the optimizer when one is given.
        model.train()
        train_loss, train_accuracy = do_epoch(
            model, train_loader, criterion, optim=optim)

        # Validation pass: no optimizer, no gradient tracking.
        model.eval()
        with torch.no_grad():
            val_loss, val_accuracy = do_epoch(
                model, val_loader, criterion, optim=None)

        tqdm.write(
            f'EPOCH {epoch:03d}: train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f} '
            f'val_loss={val_loss:.4f}, val_accuracy={val_accuracy:.4f}')

        # Checkpoint whenever validation accuracy improves.
        if val_accuracy > best_accuracy:
            print('Saving model...')
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'trained_models/source.pt')

        # Scheduler watches the validation loss, not accuracy.
        lr_schedule.step(val_loss)
コード例 #2
0
def main(args):
    """Evaluate the saved model (args.MODEL_FILE) on a target-domain image
    list and print mean per-batch accuracy.

    Supported adaptation settings: 'svhn2mnist' and 'mnist2usps'.
    """
    # The original duplicated an identical dataset construction in both
    # branches; both settings use the same Resize(28) + ToTensor pipeline,
    # so validate the setting once and build the dataset once.
    if args.adapt_setting not in ('svhn2mnist', 'mnist2usps'):
        raise NotImplementedError
    target_dataset = ImageClassdata(txt_file=args.tar_list, root_dir=args.tar_root, img_type=args.img_type,
                                    transform=transforms.Compose([
                                        transforms.Resize(28),
                                        transforms.ToTensor(),
                                    ]))
    dataloader = DataLoader(target_dataset, batch_size=args.batch_size, shuffle=False,
                            drop_last=False, num_workers=1, pin_memory=True)

    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    model.eval()

    total_accuracy = 0
    with torch.no_grad():
        for x, y_true in tqdm(dataloader, leave=False):
            x, y_true = x.to(device), y_true.to(device)
            y_pred = model(x)
            total_accuracy += (y_pred.max(1)[1] == y_true).float().mean().item()

    # NOTE(review): averaging per-batch means slightly biases the figure when
    # the last batch is smaller (drop_last=False) — confirm acceptable.
    mean_accuracy = total_accuracy / len(dataloader)
    print(f'Accuracy on target data: {mean_accuracy:.4f}')
コード例 #3
0
ファイル: app.py プロジェクト: arohanajit/haze-detection
def get_prediction(image_bytes):
    """Classify raw image bytes; return (top probability, top class index).

    Loads the checkpoint from 'model.pt' on every call and runs a softmax
    over the model's outputs.
    """
    model = Net()
    # NOTE(review): strict=False silently ignores missing/unexpected keys in
    # the checkpoint — confirm this is intentional.
    state = torch.load('model.pt', map_location='cpu')
    model.load_state_dict(state, strict=False)
    model.eval()
    tensor = transform_image(image_bytes=image_bytes)
    probabilities = F.softmax(model(tensor), dim=1)
    top_p, top_class = probabilities.topk(1, dim=1)
    return top_p, top_class
コード例 #4
0
def main():
    """Load the trained Net, classify the prepared sample, and print the
    index of the highest-scoring class.
    """
    model = Net()
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    tar = prep_data()
    # Inference only: disable autograd. torch.autograd.Variable is a
    # deprecated no-op wrapper, so the tensor is passed directly.
    with torch.no_grad():
        output = model(tar)
    res_ = np.squeeze(output.cpu().detach().numpy())
    # np.argmax returns a single index; the original
    # int(np.argwhere(res_ == np.max(res_))) raised TypeError whenever the
    # maximum value was tied (argwhere then returns multiple rows).
    print(int(np.argmax(res_)))
コード例 #5
0
def main():
    """Run webcam facial-keypoint detection: a Haar cascade finds faces and
    the trained Net predicts keypoints for them.
    """
    # Pretrained OpenCV Haar cascade for frontal-face detection.
    cascade_path = 'detector_architectures/haarcascade_frontalface_default.xml'
    face_cascade = cv2.CascadeClassifier(cascade_path)

    keypoint_net = Net()
    keypoint_net.load_state_dict(torch.load('./saved_models/keypoints_model_1.pt'))
    keypoint_net.eval()

    show_webcam(keypoint_net, face_cascade)
コード例 #6
0
def main(args):
    """Report accuracy of the saved model (args.MODEL_FILE) on the
    preprocessed target set.
    """
    X_target, y_target = preprocess_test()
    target_loader = DataLoader(
        torch.utils.data.TensorDataset(X_target, y_target),
        batch_size=args.batch_size, shuffle=False,
        num_workers=1, pin_memory=True)

    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    model.eval()

    # Accumulate the mean accuracy of each batch, then average over batches.
    total_accuracy = 0
    with torch.no_grad():
        for x, y_true in tqdm(target_loader, leave=False):
            x, y_true = x.to(device), y_true.to(device)
            batch_acc = (model(x).max(1)[1] == y_true).float().mean().item()
            total_accuracy += batch_acc

    mean_accuracy = total_accuracy / len(target_loader)
    print(f'Accuracy on target data: {mean_accuracy:.4f}')
コード例 #7
0
def main():
    """Run the depth/defocus net on every PNG in a fixed directory and show
    the input next to the (normalized) prediction with matplotlib.
    """
    args = get_args()
    print(args)

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'

    # NOTE(review): this transform is built but never applied — images below
    # are read with plt.imread and converted manually. Confirm it is dead code.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    net = Net(device=device, mode=args.target, target_mode=args.target_mode)
    net = net.eval()
    net = net.to(device)
    # Load pretrained weights from a fixed checkpoint path.
    load_model(
        net,
        device,
        fullpath='trained_models/Net_continuous_trained/checkpoint_274.pth.tar'
    )

    # Hard-coded input directory; collect all PNGs in it.
    imgs_dir = '/media/yotamg/bd0eccc9-4cd5-414c-b764-c5a7890f9785/Yotam/Real-Images/png'
    imgs_filelist = [
        os.path.join(imgs_dir, img) for img in os.listdir(imgs_dir)
        if img.endswith('.png')
    ]

    for i, img in enumerate(imgs_filelist):
        # x,x_paths, y, y_paths = data
        # HWC image -> NCHW float tensor of batch size 1.
        x = plt.imread(img)
        x = np.expand_dims(x, 0)
        x = np.transpose(x, (0, 3, 1, 2))
        # Crop 2px top/bottom and 8px left/right — presumably to match the
        # network's expected input size; TODO confirm.
        x = x[:, :, 2:-2, 8:-8]
        x = torch.Tensor(x).to(device)
        with torch.no_grad():
            out = net(x)
        out = out.detach().cpu().numpy()
        x = x.detach().cpu().numpy()
        plt.figure(1)
        # Discrete mode: class map via argmax over the channel axis.
        if args.target_mode == 'discrete':
            out = np.argmax(out, axis=1)
            out = out[0]
        # out = np.squeeze(out,0)
        # Min-max normalize the prediction for display.
        out = (out - np.min(out)) / (np.max(out) - np.min(out))
        ax1 = plt.subplot(1, 3, 1)
        # x = (x + 1) / 2
        ax1.imshow(np.transpose(x[0], (1, 2, 0)))
        ax3 = plt.subplot(1, 3, 3, sharex=ax1, sharey=ax1)
        ax3.imshow(out, cmap='jet')
        # plt.suptitle(label, fontsize="large")
        plt.show()
コード例 #8
0
ファイル: eval.py プロジェクト: ekrim/deep-homography
def pred_homography(batch,
                    patch_size,
                    model_file='homography_model.pytorch',
                    scale=32):
    """Predict the homography for a batch of image patches.

    The network's output (scaled by *scale*) is treated as per-corner pixel
    shifts; the batch-mean shift is applied to the canonical patch corners
    and the inverse homography is returned as a 3x3 numpy array.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    batch = normalize_batch(batch).to(device)
    net = Net()
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine;
    # without it torch.load raises when CUDA is unavailable.
    net.load_state_dict(torch.load(model_file, map_location=device))
    net.eval().to(device)
    with torch.no_grad():
        output = net(batch).detach().cpu().numpy() * scale
        # Average the predicted corner shifts over the batch.
        mean_shift = np.mean(output, axis=0)

        # Canonical patch corners in (x, y): TL, TR, BR, BL.
        pts1 = np.float32(
            [0, 0, patch_size, 0, patch_size, patch_size, 0,
             patch_size]).reshape(-1, 1, 2)
        pts2 = mean_shift.reshape(-1, 1, 2) + pts1

        # Invert: homography mapping shifted corners back to the canon patch.
        h = np.linalg.inv(cv2.findHomography(pts1, pts2)[0])

        return h
コード例 #9
0
def main(args):
    """Report classification accuracy of the saved model on the MNIST-M
    test split.
    """
    eval_set = MNISTM(train=False)
    eval_loader = DataLoader(eval_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             drop_last=False,
                             num_workers=1,
                             pin_memory=True)

    model = Net().to(device)
    model.load_state_dict(torch.load(args.MODEL_FILE))
    model.eval()

    # Sum per-batch mean accuracies, then divide by the batch count.
    running_acc = 0
    with torch.no_grad():
        for batch_x, batch_y in tqdm(eval_loader, leave=False):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            running_acc += (logits.max(1)[1] == batch_y).float().mean().item()

    mean_accuracy = running_acc / len(eval_loader)
    print(f'Accuracy on target data: {mean_accuracy:.4f}')
コード例 #10
0
def train(options):
    """Train a Net on the Hilbert-transform dataset described by *options*.

    options keys: exp_name, batch_size, use_pca, model_type,
    loss_fn ('mse' | 'cross_entropy'), optim ('adam' | 'sgd'),
    use_scheduler, lr, epochs, pca_var_hold, debug_mode.

    Logs train/val losses and accuracy-at-threshold metrics to TensorBoard
    under *exp_name*, and saves the model with the lowest validation loss
    to models/meh.pth.
    """
    exp_name = options['exp_name']
    batch_size = options['batch_size']
    use_pca = options['use_pca']
    model_type = options['model_type']
    loss_fn = options['loss_fn']
    optim = options['optim']
    use_scheduler = options['use_scheduler']
    lr = options['lr']
    epochs = options['epochs']
    pca_var_hold = options['pca_var_hold']
    debug_mode = options['debug_mode']

    # Start each run from a clean TensorBoard log directory.
    if os.path.exists(exp_name):
        shutil.rmtree(exp_name)

    time.sleep(1)  # give the filesystem a moment before recreating the dir
    writer = SummaryWriter(exp_name, flush_secs=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 80/20 train/validation split over the data-file listing.
    X = os.listdir('hilbert_data')
    X_train = X[:int(0.8 * len(X))]
    X_test = X[int(0.8 * len(X)):]

    dataset_train = HIL_MOOD(X_train, model_type=model_type,
                             data_type='train', debug_mode=debug_mode)
    train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size,
                              shuffle=True)

    dataset_val = HIL_MOOD(X_test, model_type=model_type, data_type='val')
    valid_loader = DataLoader(dataset=dataset_val, batch_size=batch_size,
                              shuffle=False)

    model = Net()
    model.to(device)

    if optim is None:
        print('you need to specify an optimizer')
        exit()
    elif optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', verbose=True, threshold=0.0001, patience=10)

    if loss_fn is None:
        # BUG FIX: message previously said "optimizer" here too.
        print('you need to specify a loss function')
        exit()
    elif loss_fn == 'mse':
        loss_fn = torch.nn.MSELoss()
    elif loss_fn == 'cross_entropy':
        loss_fn = torch.nn.CrossEntropyLoss()

    mean_train_losses = []
    mean_valid_losses = []
    valid_acc_list = []
    # BUG FIX: this tracks the LOWEST validation loss, so it must start at
    # +inf. The original initialized it to 0, so `valid_loss < best` could
    # never trigger and no checkpoint was ever written.
    best = float('inf')

    for epoch in range(epochs):
        model.train()
        train_losses = []
        valid_losses = []
        for i, (images, labels) in enumerate(train_loader):
            # Skip ragged final batches (model assumes a fixed batch size).
            if images.shape[0] != batch_size:
                continue
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            outputs = model(images)
            loss = loss_fn(outputs, labels)
            # Global step = batches seen across all epochs.
            writer.add_scalar('Loss/train', loss.item(),
                              len(train_loader) * epoch + i)

            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            del outputs

        # ---- validation ----
        model.eval()
        total = 0
        accsat = [0.5, 0.05, 0.005]  # accuracy-at-threshold levels
        accs = np.zeros(len(accsat))
        correct_array = np.zeros(len(accsat))
        with torch.no_grad():
            for images, labels in valid_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = loss_fn(outputs, labels)

                # j, not i: the original shadowed the batch index here.
                for j in range(len(accsat)):
                    correct_array[j] += accat(outputs, labels,
                                              thresh=accsat[j])

                total += labels.size(0)
                valid_losses.append(loss.item())

        mean_train_losses.append(np.mean(train_losses))
        mean_valid_losses.append(np.mean(valid_losses))
        for j in range(len(accsat)):
            accs[j] = 100 * correct_array[j] / total
            writer.add_scalar('Acc/val_@' + str(accsat[j]), accs[j], epoch)

        # Checkpoint on a new best (lowest) validation loss.
        if np.mean(valid_losses) < best:
            best = np.mean(valid_losses)
            torch.save(model.state_dict(),
                       os.path.join(os.getcwd(), 'models', 'meh.pth'))

        writer.add_scalar('Loss/val', np.mean(valid_losses), epoch)
        if epoch == epochs - 1:
            # BUG FIX: previously printed the threshold value (accsat[1])
            # where the measured accuracy (accs[1]) was intended.
            print('epoch : {}, train loss : {:.4f}, valid loss : {:.4f}, [email protected] : {:.4f}'
                  .format(epoch + 1, np.mean(train_losses),
                          np.mean(valid_losses), accs[1]))
コード例 #11
0
def main(args):
    """ADDA-style adversarial domain adaptation from MNIST to MNIST-M.

    Starts from a source-trained Net (args.MODEL_FILE), freezes its feature
    extractor + classifier, and adversarially trains a copy of the feature
    extractor on the target domain against a domain discriminator. The
    adapted model is saved to trained_models/adda.pt every epoch.
    """
    # Frozen source network; `clf` keeps a handle on the full model.
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)

    clf = source_model
    source_model = source_model.feature_extractor

    # Target feature extractor starts as a copy of the source weights.
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor
    target_clf = clf.classifier

    # Domain discriminator over flattened 320-dim features; outputs a logit
    # (no sigmoid — paired with BCEWithLogitsLoss below).
    discriminator = nn.Sequential(nn.Linear(320, 50), nn.ReLU(),
                                  nn.Linear(50, 20), nn.ReLU(),
                                  nn.Linear(20, 1)).to(device)

    # Each domain contributes half of every discriminator batch.
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR / 'mnist',
                           train=True,
                           download=True,
                           transform=Compose([GrayscaleToRgb(),
                                              ToTensor()]))
    source_loader = DataLoader(source_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset,
                               batch_size=half_batch,
                               shuffle=True,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    target_optim = torch.optim.Adam(target_model.parameters())
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        # Endless paired iterator over both domains.
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        total_accuracy = 0
        target_label_accuracy = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator (k_disc steps): target extractor frozen.
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            for _ in range(args.k_disc):
                (source_x, _), (target_x, _) = next(batch_iterator)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Labels: 1 = source domain, 0 = target domain.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                # Logit > 0 corresponds to sigmoid > 0.5.
                total_accuracy += ((
                    preds >
                    0).long() == discriminator_y.long()).float().mean().item()

            # Train target extractor (k_clf steps): discriminator frozen.
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            for _ in range(args.k_clf):
                _, (target_x, target_labels) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # flipped labels
                # (target samples labeled as "source" so the extractor
                # learns to fool the discriminator)
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                loss = criterion(preds, discriminator_y)

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

                # Monitoring only: label accuracy via the frozen classifier.
                target_label_preds = target_clf(target_features)
                target_label_accuracy += (target_label_preds.cpu().max(1)[1] ==
                                          target_labels).float().mean().item()

        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_accuracy = total_accuracy / (args.iterations * args.k_disc)
        target_mean_accuracy = target_label_accuracy / (args.iterations *
                                                        args.k_clf)
        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, '
            f'discriminator_accuracy={mean_accuracy:.4f}, target_accuracy={target_mean_accuracy:.4f}'
        )

        # Create the full target model and save it
        # (overwrites the checkpoint every epoch, not just on improvement).
        clf.feature_extractor = target_model
        torch.save(clf.state_dict(), 'trained_models/adda.pt')
コード例 #12
0
def main(args):
    """ADDA variant for per-person adaptation over preprocessed tensors.

    Freezes a source-trained Net, adversarially adapts a copy of its
    feature extractor to the target subject (args.person), checkpoints the
    best model by target accuracy, and dumps per-epoch accuracies to
    <seed>/acc<person>.json.
    """
    # Frozen source network; `clf` keeps a handle on the full model.
    source_model = Net().to(device)
    source_model.load_state_dict(torch.load(args.MODEL_FILE))
    source_model.eval()
    set_requires_grad(source_model, requires_grad=False)

    clf = source_model
    source_model = source_model.feature_extractor

    # Target feature extractor starts as a copy of the source weights.
    target_model = Net().to(device)
    target_model.load_state_dict(torch.load(args.MODEL_FILE))
    target_model = target_model.feature_extractor

    classifier = clf.classifier

    # NOTE(review): the discriminator ends in nn.Sigmoid() but the loss below
    # is BCEWithLogitsLoss, which applies sigmoid again — confirm intended.
    discriminator = nn.Sequential(nn.Linear(EXTRACTED_FEATURE_DIM, 64),
                                  nn.ReLU(), nn.BatchNorm1d(64),
                                  nn.Linear(64, 1), nn.Sigmoid()).to(device)

    #half_batch = args.batch_size // 2

    batch_size = args.batch_size

    # X_source, y_source = preprocess_train()
    X_source, y_source = preprocess_train_single(1)
    source_dataset = torch.utils.data.TensorDataset(X_source, y_source)

    source_loader = DataLoader(source_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)

    X_target, y_target = preprocess_test(args.person)
    target_dataset = torch.utils.data.TensorDataset(X_target, y_target)
    target_loader = DataLoader(target_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=1,
                               pin_memory=True)

    discriminator_optim = torch.optim.Adam(discriminator.parameters())
    # Very small LR for the adapted extractor to avoid destroying features.
    target_optim = torch.optim.Adam(target_model.parameters(), lr=3e-6)
    criterion = nn.BCEWithLogitsLoss()
    criterion_class = nn.CrossEntropyLoss()

    # Baseline accuracy of the unadapted source model on the target.
    best_tar_acc = test(args, clf)
    final_accs = []

    for epoch in range(1, args.epochs + 1):
        # Rebuild loaders with shuffling each epoch.
        source_loader = DataLoader(source_loader.dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        target_loader = DataLoader(target_loader.dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
        batch_iterator = zip(loop_iterable(source_loader),
                             loop_iterable(target_loader))

        total_loss = 0
        adv_loss = 0
        total_accuracy = 0
        second_acc = 0
        total_class_loss = 0
        for _ in trange(args.iterations, leave=False):
            # Train discriminator
            set_requires_grad(target_model, requires_grad=False)
            set_requires_grad(discriminator, requires_grad=True)
            discriminator.train()
            for _ in range(args.k_disc):
                (source_x, source_y), (target_x, _) = next(batch_iterator)
                source_y = source_y.to(device).view(-1)
                source_x, target_x = source_x.to(device), target_x.to(device)

                source_features = source_model(source_x).view(
                    source_x.shape[0], -1)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)

                # Labels: 1 = source domain, 0 = target domain.
                discriminator_x = torch.cat([source_features, target_features])
                discriminator_y = torch.cat([
                    torch.ones(source_x.shape[0], device=device),
                    torch.zeros(target_x.shape[0], device=device)
                ])

                preds = discriminator(discriminator_x).squeeze()
                loss = criterion(preds, discriminator_y)

                discriminator_optim.zero_grad()
                loss.backward()
                discriminator_optim.step()

                total_loss += loss.item()
                total_accuracy += ((preds >= 0.5).long() == discriminator_y.
                                   long()).float().mean().item()

            # Train feature extractor
            set_requires_grad(target_model, requires_grad=True)
            set_requires_grad(discriminator, requires_grad=False)
            target_model.train()
            for _ in range(args.k_clf):
                _, (target_x, _) = next(batch_iterator)
                target_x = target_x.to(device)
                target_features = target_model(target_x).view(
                    target_x.shape[0], -1)
                # source_x reuses the last batch from the discriminator phase.
                source_features = target_model(source_x).view(
                    source_x.shape[0], -1)
                source_pred = classifier(source_features)  # (batch_size, 4)

                # flipped labels
                # (target samples labeled "source" to fool the discriminator)
                discriminator_y = torch.ones(target_x.shape[0], device=device)

                preds = discriminator(target_features).squeeze()
                second_acc += ((preds >= 0.5).long() == discriminator_y.long()
                               ).float().mean().item()

                loss_adv = criterion(preds, discriminator_y)
                adv_loss += loss_adv.item()
                # Classification loss is computed for monitoring only — it is
                # excluded from the optimized loss below.
                loss_class = criterion_class(source_pred, source_y)
                total_class_loss += loss_class.item()
                loss = loss_adv  #+ 0.001*loss_class

                target_optim.zero_grad()
                loss.backward()
                target_optim.step()

        # Epoch-level monitoring statistics.
        mean_loss = total_loss / (args.iterations * args.k_disc)
        mean_adv_loss = adv_loss / (args.iterations * args.k_clf)
        total_class_loss = total_class_loss / (args.iterations * args.k_clf)
        dis_accuracy = total_accuracy / (args.iterations * args.k_disc)
        sec_acc = second_acc / (args.iterations * args.k_clf)
        # Evaluate the adapted extractor and checkpoint only on improvement.
        clf.feature_extractor = target_model
        tar_accuarcy = test(args, clf)
        final_accs.append(tar_accuarcy)
        if tar_accuarcy > best_tar_acc:
            best_tar_acc = tar_accuarcy
            torch.save(clf.state_dict(), 'trained_models/adda.pt')

        tqdm.write(
            f'EPOCH {epoch:03d}: discriminator_loss={mean_loss:.4f}, adv_loss = {mean_adv_loss:.4f}, '
            f'discriminator_accuracy={dis_accuracy:.4f}, tar_accuary = {tar_accuarcy:.4f}, best_accuracy = {best_tar_acc:.4f}, '
            f'sec_acc = {sec_acc:.4f}, total_class_loss: {total_class_loss:.4f}'
        )

        # Create the full target model and save it
        clf.feature_extractor = target_model
        #torch.save(clf.state_dict(), 'trained_models/adda.pt')
    # Persist the per-epoch target accuracies for later analysis.
    jd = {"test_acc": final_accs}
    with open(str(args.seed) + '/acc' + str(args.person) + '.json', 'w') as f:
        json.dump(jd, f)
コード例 #13
0
class Trainer(object):
    """Trains and evaluates a Net classifier with NLL loss, Adadelta, a
    per-epoch StepLR schedule, TensorBoard logging, and periodic
    iteration-based checkpointing.
    """

    def __init__(self, train_loader, test_loader, config):
        """Build model, optimizer, scheduler and writer from *config*.

        config is expected to provide: device, n_epoch, lr, gamma, log_dir,
        log_interval, checkpoint_interval, checkpoint_dir, dataset_name.
        """
        self.config = config
        self.device = config.device

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.n_epoch = config.n_epoch
        self.lr = config.lr
        self.gamma = config.gamma
        self.device = config.device

        # self.start_epoch = 1
        # Global iteration counter; _load_models can move it forward.
        self.start_itr = 1

        # Output dimension follows the dataset's class list.
        n_classes = len(self.train_loader.dataset.classes)
        self.model = Net(n_classes=n_classes).to(self.device)
        print(self.model)
        print('Initialized model...\n')

        self.optim = torch.optim.Adadelta(self.model.parameters(), self.lr)
        # Decay LR by gamma after every epoch.
        self.scheduler = StepLR(self.optim, step_size=1, gamma=self.gamma)

        # if not self.config.model_state_path == '':
        #     self._load_models(self.config.model_state_path)

        self.writer = SummaryWriter(log_dir=self.config.log_dir)

    def train(self):
        """Run the full training loop, testing after every epoch."""
        self.model.train()

        n_itr = self.start_itr
        print('Start training...!')
        for epoch in range(1, self.n_epoch + 1):
            with tqdm(total=len(self.train_loader)) as pbar:
                for idx, (img, label) in enumerate(self.train_loader):
                    pbar.set_description(
                        f'Epoch[{epoch}/{self.n_epoch}], iteration[{idx}/{len(self.train_loader)}]'
                    )

                    img, label = img.to(self.device), label.to(self.device)

                    # Standard step: forward, NLL loss, backward, update.
                    self.optim.zero_grad()
                    output = self.model(img)
                    loss = F.nll_loss(output, label)
                    loss.backward()
                    self.optim.step()

                    # Periodic scalar logging.
                    if n_itr % self.config.log_interval == 0:
                        pbar.set_postfix(OrderedDict(loss=loss.item()))
                        tqdm.write(
                            f'Epoch[{epoch}], iteration[{idx}/{len(self.train_loader)}], loss: {loss.item()}'
                        )
                        self.writer.add_scalar('loss/loss', loss.item(), n_itr)

                    # Periodic checkpointing by global iteration count.
                    if n_itr % self.config.checkpoint_interval == 0:
                        self._save_models(epoch, n_itr)

                    n_itr += 1
                    pbar.update()
            self.scheduler.step()
            self.test(n_itr)

        self.writer.close()

    def test(self, n_itr):
        """Evaluate on the test set; log accuracy at step *n_itr* and
        restore train mode afterwards.
        """
        self.model.eval()
        test_loss = 0
        correct = 0
        print('Start testing...!')
        with torch.no_grad():
            for _, (img, label) in enumerate(self.test_loader):
                img, label = img.to(self.device), label.to(self.device)
                output = self.model(img)
                # Sum (not mean) so the division below gives a dataset mean.
                test_loss += F.nll_loss(output, label, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(label.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        accuracy = correct / len(self.test_loader.dataset)
        self.writer.add_scalar('accuracy/test_accuracy', accuracy, n_itr)
        tqdm.write(
            f'Test: Average loss: {test_loss}, Accuracy: {accuracy * 100.0}%')
        self.model.train()

    def _save_models(self, epoch, n_itr):
        """Write model + optimizer state to a per-iteration checkpoint file."""
        checkpoint_name = f'{self.config.dataset_name}_model_ckpt_{n_itr}.pt'
        checkpoint_path = os.path.join(self.config.checkpoint_dir,
                                       checkpoint_name)
        torch.save(
            {
                # 'epoch': epoch,
                'n_itr': n_itr,
                'model': self.model.state_dict(),
                'optim': self.optim.state_dict(),
            },
            checkpoint_path)
        tqdm.write(f'Saved models state_dict: n_itr_{n_itr}')

    def _load_models(self, model_state_path):
        """Restore model/optimizer state and resume the iteration counter."""
        checkpoint = torch.load(model_state_path)
        # self.start_epoch = checkpoint['epoch']
        self.start_itr = checkpoint['n_itr'] + 1
        self.model.load_state_dict(checkpoint['model'])
        self.optim.load_state_dict(checkpoint['optim'])
        print(f'start_itr: {self.start_itr}')
        print('Loaded pretrained models...\n')
コード例 #14
0
# Detect faces with a pretrained Haar cascade detector.
face_cascade = cv2.CascadeClassifier(
    'detector_architectures/haarcascade_frontalface_default.xml')
# NOTE(review): `image` must be defined earlier in this script.
faces = face_cascade.detectMultiScale(image, 1.2, 2)
# image_with_detections = image.copy()
# for (x, y, w, h) in faces:
#     cv2.rectangle(image_with_detections, (x, y), (x + w, y + h), (255, 0, 0), 3)
# fig = plt.figure(figsize=(9, 9))
# plt.imshow(image_with_detections)

# Run keypoint detection with our own trained network:
# 1. load the trained network weights
net = Net()
net.load_state_dict(
    torch.load('./saved_models/krunal_keypoints_model_lr0001_epoch20.pt'))
print(net.eval())


def show_points(image_test, key_points):
    """Plot detected facial key points over the test image.

    :param image_test: grayscale image to display
    :param key_points: predicted key points (tensor); de-normalized by
        *60 + 68 — presumably undoing the training-time normalization,
        TODO confirm — and reshaped to 68 (x, y) pairs before plotting
    :return: None
    """
    plt.figure()
    pts = key_points.data.numpy()
    pts = pts * 60.0 + 68
    pts = np.reshape(pts, (68, -1))
    plt.imshow(image_test, cmap='gray')
    plt.scatter(pts[:, 0], pts[:, 1], s=50, marker='.', c='r')
コード例 #15
0
ファイル: test_new.py プロジェクト: YotYot/DL_project
def main():
    args = get_args()
    print(args)

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    test_dir = '/media/yotamg/bd0eccc9-4cd5-414c-b764-c5a7890f9785/Yotam/Sintel/Filtered/rgb'
    label_dir = '/media/yotamg/bd0eccc9-4cd5-414c-b764-c5a7890f9785/Yotam/Sintel/Filtered/GT'

    test_filelist = [
        os.path.join(test_dir, img) for img in os.listdir(test_dir)
        if 'alley_1' in img
    ]
    test_labels_filelist = [
        img.replace(test_dir, label_dir).replace('_1100_maskImg.png',
                                                 '_GT.dpt')
        for img in test_filelist
    ]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = Dataset(image_filelist=test_filelist,
                           label_filelist=test_labels_filelist,
                           transforms=transform,
                           pickle_name='test.pickle',
                           train=False)

    test_data_loader = torch.utils.data.DataLoader(test_dataset,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=8)

    net = Net(device=device, mode=args.target, target_mode=args.target_mode)
    net = net.eval()
    net = net.to(device)
    load_model(
        net,
        device,
        fullpath=
        '/home/yotamg/PycharmProjects/dfd/trained_models/Net_default/checkpoint_76.pth.tar'
    )

    for i, data in enumerate(test_data_loader):
        # x,x_paths, y, y_paths = data
        x, x_path, y, y_path = data
        x = x.to(device)
        with torch.no_grad():
            out = net(x)
        out = out.detach().cpu().numpy()
        x = x.detach().cpu().numpy()
        plt.figure(1)
        out = np.argmax(out, axis=1)
        out = np.squeeze(out, 0)
        out = (out - np.min(out)) / (np.max(out) - np.min(out))
        ax1 = plt.subplot(1, 3, 2)
        ax1.imshow(y[0])
        ax2 = plt.subplot(1, 3, 1, sharex=ax1, sharey=ax1)
        x = (x + 1) / 2
        ax2.imshow(np.transpose(x[0], (1, 2, 0)))
        ax3 = plt.subplot(1, 3, 3, sharex=ax1, sharey=ax1)
        ax3.imshow(out)
        # plt.suptitle(label, fontsize="large")
        plt.show()
コード例 #16
0
        loss.backward()
        optimizer.step()
        #记录误差 一个数
        loss_tr += loss.item()

    #累计当前批次的损失
    loss_tr_epoch = loss_tr / len(loader_tr)
    print(f'train done. {process_time()- start} sec, loss={loss_tr_epoch:.4f}')
    losses_tr.append(loss_tr_epoch)
    # 保存loss的数据与epoch数值
    writer.add_scalar('Train', loss_tr_epoch, epoch)

    loss_ts = 0
    #改成eval模式
    model.eval()
    for data in loader_ts:
        #x y 送入 gpu
        data = data.to(device)

        out = model(data)
        #print(out[0,:])
        #print(data.y[0,:])
        loss = criterion(out, data.y)

        #记录误差 一个数
        loss_ts += loss.item()

    #累计当前批次的损失
    loss_ts_epoch = loss_ts / len(loader_ts)
    print(f'test sec, loss={loss_ts_epoch:.4f}')
コード例 #17
0
def train(data_folder, output_folder, es_patience, epochs, TTA):
    """Train an EfficientNet-B1 melanoma classifier with 5-fold CV.

    For each fold: trains with early stopping on validation ROC-AUC,
    reloads the best checkpoint, fills the out-of-fold predictions and
    accumulates TTA-averaged test-set predictions.

    Args:
        data_folder: directory holding train.csv / test.csv and the
            train/ and test/ image folders.
        output_folder: directory where per-fold checkpoints are written.
        es_patience: epochs without val ROC-AUC improvement before stopping.
        epochs: maximum number of epochs per fold.
        TTA: number of test-time-augmentation passes over the test set.

    Returns:
        (preds, oof): test predictions averaged over TTA passes and folds,
        and out-of-fold validation predictions aligned with the train rows.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    print("使用デバイス:", device)

    train = pd.read_csv(data_folder + '/train.csv')
    test = pd.read_csv(data_folder + '/test.csv')

    arch = EfficientNet.from_pretrained('efficientnet-b1')  # backbone model

    # Every CSV column except identifiers/labels is used as a meta feature.
    meta_features = list(train.columns)
    meta_features.remove('image_name')
    meta_features.remove('target')
    meta_features.remove('fold')

    # Accumulators: out-of-fold predictions and test-set predictions.
    oof = np.zeros((len(train), 1))
    preds = torch.zeros((len(test), 1), dtype=torch.float32, device=device)

    skf = KFold(n_splits=5, shuffle=True, random_state=47)

    #===========================
    # cross-validation loop
    #===========================
    for fold, (idxT, idxV) in enumerate(skf.split(np.arange(15)), 1):
        print('=' * 20, 'Fold', fold, '=' * 20)

        # Map the fold-id split onto row indices via the precomputed
        # 'fold' column (KFold here splits fold ids 0..14, not rows).
        train_idx = train.loc[train['fold'].isin(idxT)].index
        val_idx = train.loc[train['fold'].isin(idxV)].index

        # per-fold training state
        model_path = f'/model_{fold}.pth'
        best_val = 0
        patience = es_patience

        model = Net(arch=arch, n_meta_features=len(meta_features))

        model = model.to(device)

        optim = torch.optim.Adam(model.parameters(), lr=0.001)
        # mode='max': the scheduler tracks validation ROC-AUC.
        scheduler = ReduceLROnPlateau(optimizer=optim,
                                      mode='max',
                                      patience=1,
                                      verbose=True,
                                      factor=0.2)

        criterion = nn.BCEWithLogitsLoss()

        trainDataset = MelanomaDataset(
            df=train.iloc[train_idx].reset_index(drop=True),
            imfolder=data_folder + '/train',
            train=True,
            transforms=train_transform,
            meta_features=meta_features)

        valDataset = MelanomaDataset(
            df=train.iloc[val_idx].reset_index(drop=True),
            imfolder=data_folder + '/train',
            train=True,
            transforms=test_transform,
            meta_features=meta_features)

        testDataset = MelanomaDataset(df=test,
                                      imfolder=data_folder + '/test',
                                      train=False,
                                      transforms=test_transform,
                                      meta_features=meta_features)

        train_loader = DataLoader(dataset=trainDataset,
                                  batch_size=64,
                                  shuffle=True,
                                  num_workers=2)
        val_loader = DataLoader(dataset=valDataset,
                                batch_size=16,
                                shuffle=False,
                                num_workers=2)
        test_loader = DataLoader(dataset=testDataset,
                                 batch_size=16,
                                 shuffle=False,
                                 num_workers=2)

        #=====================
        # epochs
        #=====================
        for epoch in range(epochs):
            start_time = time.time()
            correct = 0
            epoch_loss = 0

            # training loop
            model.train()
            for x, y in train_loader:
                # .to(...) instead of torch.tensor(existing_tensor, ...):
                # avoids the copy-construct warning and redundant copies.
                x[0] = x[0].to(device=device, dtype=torch.float32)
                x[1] = x[1].to(device=device, dtype=torch.float32)
                y = y.to(device=device, dtype=torch.float32)
                optim.zero_grad()
                z = model(x)
                loss = criterion(z, y.unsqueeze(1))
                loss.backward()
                optim.step()
                pred = torch.round(torch.sigmoid(z))
                correct += (pred.cpu() == y.cpu().unsqueeze(1)).sum().item()
                epoch_loss += loss.item()
            train_acc = correct / len(train_idx)

            model.eval()
            val_preds = torch.zeros((len(val_idx), 1),
                                    dtype=torch.float32,
                                    device=device)

            with torch.no_grad():
                # validation loop
                for j, (x_val, y_val) in enumerate(val_loader):
                    x_val[0] = x_val[0].to(device=device, dtype=torch.float32)
                    x_val[1] = x_val[1].to(device=device, dtype=torch.float32)
                    y_val = y_val.to(device=device, dtype=torch.float32)
                    z_val = model(x_val)
                    val_pred = torch.sigmoid(z_val)
                    val_preds[j *
                              val_loader.batch_size:j * val_loader.batch_size +
                              x_val[0].shape[0]] = val_pred

                val_acc = accuracy_score(train.iloc[val_idx]['target'].values,
                                         torch.round(val_preds.cpu()))
                val_roc = roc_auc_score(train.iloc[val_idx]['target'].values,
                                        val_preds.cpu())

                print(
                    'Epoch{:03}: | Loss:{:.3f} | Train acc:{:.3f} | Val acc:{:.3f} | Val roc_auc:{:.3f} | Training time:{}'
                    .format(
                        epoch + 1, epoch_loss, train_acc, val_acc, val_roc,
                        str(
                            datetime.timedelta(seconds=time.time() -
                                               start_time))[:7]))

                scheduler.step(val_roc)

                if val_roc >= best_val:
                    best_val = val_roc
                    patience = es_patience

                    torch.save(model, output_folder + model_path)

                else:
                    patience -= 1
                    if patience == 0:
                        print('Early stopping. Best Val roc_auc:{:.3f}'.format(
                            best_val))
                        break

        # Reload the best checkpoint of this fold for final evaluation.
        model = torch.load(output_folder + model_path)
        model.eval()
        val_preds = torch.zeros((len(val_idx), 1),
                                dtype=torch.float32,
                                device=device)

        # evaluation loop
        with torch.no_grad():
            for j, (x_val, y_val) in enumerate(val_loader):
                x_val[0] = x_val[0].to(device=device, dtype=torch.float32)
                x_val[1] = x_val[1].to(device=device, dtype=torch.float32)
                y_val = y_val.to(device=device, dtype=torch.float32)
                z_val = model(x_val)
                val_pred = torch.sigmoid(z_val)
                val_preds[j * val_loader.batch_size:j * val_loader.batch_size +
                          x_val[0].shape[0]] = val_pred
            oof[val_idx] = val_preds.cpu().numpy()

            for _ in range(TTA):
                for i, x_test in enumerate(test_loader):
                    x_test[0] = x_test[0].to(device=device,
                                             dtype=torch.float32)
                    x_test[1] = x_test[1].to(device=device,
                                             dtype=torch.float32)
                    z_test = model(x_test)
                    z_test = torch.sigmoid(z_test)
                    # Bug fix: accumulate each TTA pass pre-divided by TTA.
                    # The old code did `preds /= TTA` INSIDE this loop, so
                    # earlier passes (and earlier folds' contributions) were
                    # repeatedly re-divided and underweighted.
                    preds[i *
                          test_loader.batch_size:i * test_loader.batch_size +
                          x_test[0].shape[0]] += z_test / TTA

    # Average the per-fold TTA means over all folds.
    preds /= skf.n_splits

    return preds, oof
コード例 #18
0
def test(pathModel, nnClassCount, testTensor, trBatchSize):
    """Classify every image in testTensor as Broken/Normal and print counts.

    Args:
        pathModel: path to a checkpoint whose 'state_dict' entry holds the
            Net weights; a missing file is reported and inference proceeds
            with the untrained network (original behavior).
        nnClassCount: unused; kept for backward compatibility with callers.
        testTensor: stacked image tensor; dim 0 indexes samples.
        trBatchSize: inference batch size.
    """
    print("\n\n\n")
    print("Inside test funtion")

    CLASS_NAMES = ['Broken', 'Normal']  # index 0 = broken, 1 = normal
    model = Net()
    print("Is model==None:", model is None)
    print("Is pathModel==None:", pathModel is None)

    if os.path.isfile(pathModel):
        print("=> loading checkpoint: ", pathModel)
        modelCheckpoint = torch.load(pathModel, map_location='cpu')
        # strict=False tolerates key mismatches between checkpoint and Net.
        model.load_state_dict(modelCheckpoint['state_dict'], strict=False)
        print("=> loaded checkpoint: ", pathModel)
    else:
        print("=> no checkpoint found: ")

    print(
        "\n============================ Loading data into RAM ======================================== "
    )

    testImage = testTensor
    testSize = testImage.size()[0]

    print(
        "============================= Evaluation of model starts ===================================="
    )
    model.eval()

    broken = 0
    normal = 0
    batchID = 1

    with torch.no_grad():

        for i in range(0, testSize, trBatchSize):

            # `% 1` is always 0, so this prints every batch (kept as-is).
            if (batchID % 1) == 0:
                print("batchID:" + str(batchID) + '/' +
                      str(testImage.size()[0] / trBatchSize))

            # Slicing past the end is safe, so one slice covers both the
            # full-batch and the final partial-batch cases.
            batch = testImage[i:i + trBatchSize]

            batch = trans_test(batch)
            batch = batch.type(torch.FloatTensor)

            # torch.autograd.Variable is deprecated; tensors work directly.
            out = model(batch)

            _, predicted = torch.max(out.data, 1)

            print(predicted)

            # Tally predictions over the actual batch length; this replaces
            # the old duplicated full/partial-batch branches.
            for k in range(len(predicted)):
                if predicted[k] == 1:
                    normal += 1
                elif predicted[k] == 0:
                    broken += 1

            batchID += 1

    print(' Number of broken grains in sample : ', broken)

    print(' Number of normal grains in sample :', normal)
コード例 #19
0
def main():
    """Stream frames from a RealSense camera, detect faces, and overlay
    predicted facial keypoints; annotated frames go to output_4.avi.

    Runs for at most 1000 frames. Press 'q' to quit early; 'f' toggles the
    frame-saving flag (NOTE(review): the code that actually wrote frames to
    disk was removed, so the flag currently has no effect — confirm intent).

    Returns:
        0 on normal shutdown.
    """
    net = Net()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net.to(device)
    net.load_state_dict(torch.load('saved_models/keypoints_model_2.pt'))

    ## print out your net and prepare it for testing (uncomment the line below)
    net.eval()

    # Fire camera & launch streams
    serv = pyrs.Service()
    cam = serv.Device(
        device_id=0,
        streams=[
            pyrs.stream.ColorStream(fps=60),
        ])

    # Some important variables
    flag_save_frames = False
    file_num = 0

    # Output video writer: MJPG, 20 fps, 640x480.
    out = cv2.VideoWriter('output_4.avi',
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 20,
                          (640, 480))

    # Load the Haar cascade once, outside the frame loop — it was previously
    # re-read from disk on every single frame.
    face_cascade = cv2.CascadeClassifier(
        'detector_architectures/haarcascade_frontalface_default.xml')

    # Start fetching Buffer
    print('Starting Buffer...')
    i = 1000
    while (i):
        cam.wait_for_frames()
        image_1 = cam.color[:, :, ::-1]
        gray_1 = cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY)
        faces_1 = face_cascade.detectMultiScale(gray_1, 1.1, 5)

        # make a copy of the original image to plot detections on
        image_with_detections_1 = image_1.copy()

        # loop over the detected faces, mark the image where each face is found
        for (x, y, w, h) in faces_1:
            roi = gray_1[y:y + int(h), x:x + int(w)]
            org_shape = roi.shape
            roi = roi / 255.0

            roi = cv2.resize(roi, (224, 224))
            roi = roi.reshape(roi.shape[0], roi.shape[1], 1)
            roi = np.transpose(roi, (2, 0, 1))
            roi = torch.from_numpy(roi)
            # Use the selected device instead of the hard-coded
            # torch.cuda.FloatTensor (which crashed on CPU-only machines),
            # and skip autograd bookkeeping during inference.
            roi = roi.to(device=device, dtype=torch.float32)
            roi = roi.unsqueeze(0)
            with torch.no_grad():
                predicted_key_pts = net(roi)
            predicted_key_pts = predicted_key_pts.view(68, -1)
            predicted_key_pts = predicted_key_pts.cpu().numpy()
            # Undo training-time normalization (scale 50, offset 100) —
            # presumably matches the dataset transform; confirm.
            predicted_key_pts = predicted_key_pts * 50.0 + 100

            # Map keypoints from the 224x224 crop back to frame coordinates.
            predicted_key_pts[:, 0] = predicted_key_pts[:, 0] * org_shape[
                0] / 224 + x
            predicted_key_pts[:, 1] = predicted_key_pts[:, 1] * org_shape[
                1] / 224 + y

            for (x_point, y_point) in zip(predicted_key_pts[:, 0],
                                          predicted_key_pts[:, 1]):
                # cv2 drawing functions require integer pixel coordinates;
                # floats raise in recent OpenCV versions.
                cv2.circle(image_with_detections_1,
                           (int(x_point), int(y_point)), 3, (0, 255, 0), -1)

        out.write(image_with_detections_1)
        cv2.imshow('Color', image_with_detections_1)

        i = i - 1
        k = cv2.waitKey(1)
        if k == ord('q'):
            print('Q Pressed...\nEnding execution')
            break
        if k == ord('f'):
            if flag_save_frames:
                print('F Pressed...\nStopped fetching frames...')
                flag_save_frames = False
            else:
                print('F Pressed...\nStarted fetching frames...')
                flag_save_frames = True

    cam.stop()
    out.release()
    serv.stop()
    return 0
コード例 #20
0
import cv2
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from data_load import FacialKeypointsDataset, Rescale, RandomCrop, Normalize, ToTensor
from models import Net

# Restore a facial-keypoint network from a pre-trained checkpoint.
net = Net()
net.load_state_dict(
    torch.load(
        r'C:\Users\Semanti Basu\Documents\OneDrive_2020-02-19\3D Ceaser dataset\Image and point generation\Image and point generation\frontaltrainedmodel_10epoch.pth'
    ))

## print out your net and prepare it for testing (uncomment the line below)
net.eval()
# Preprocessing pipeline: rescale shorter side to 225, random 224x224 crop,
# then the project's Normalize/ToTensor transforms.
data_transform = transforms.Compose(
    [Rescale(225), RandomCrop(224),
     Normalize(), ToTensor()])
transformed_dataset = FacialKeypointsDataset(
    csv_file=
    r'C:\Users\Semanti Basu\Documents\OneDrive_2020-02-19\3D Ceaser dataset\Image and point generation\Image and point generation\frontalpoints.csv',
    root_dir=
    r'C:\Users\Semanti Basu\Documents\OneDrive_2020-02-19\3D Ceaser dataset\Image and point generation\Image and point generation\ceasar_mat',
    transform=data_transform)
# load training data in batches
batch_size = 10

# NOTE(review): this snippet is truncated in the source — the DataLoader call
# below is cut off mid-argument-list.
train_loader = DataLoader(transformed_dataset,
                          batch_size=batch_size,
                          shuffle=True,
コード例 #21
0
class NNet():
    """
    Wrapper to manage neural net.
    """
    def __init__(self, args):
        """Build the network, load the dataset and set up logging.

        Args:
            args: configuration namespace; must provide ``netType`` (1 or 2),
                ``cuda``, and the dataset/loader options read by
                ``load_dataset_from_folder``.

        Raises:
            ValueError: if ``args.netType`` is not 1 or 2.
        """
        self.args = args
        self.num_channels = NUM_CHANNELS

        if args.netType == 1:
            self.net = Net(self.num_channels, args)
        elif args.netType == 2:
            self.net = Net2(self.num_channels, args)
        else:
            # Fail fast: previously an unknown netType left self.net unset
            # and crashed later with a confusing AttributeError.
            raise ValueError('Unsupported netType: {}'.format(args.netType))

        if args.cuda:
            self.net = self.net.cuda()

        self.load_dataset_from_folder()
        self.writer = SummaryWriter()
        # Unique token used to name checkpoints produced by this run.
        self.unique_tok = str(time.time())
        self.init_weights()

    def init_weights(self):
        """
        Initialize by Xavier weights
        """
        # Delegates to the module-level `init_weights` helper, applied
        # recursively to every submodule of the network.
        self.net.apply(init_weights)

    def load_dataset_from_folder(self):
        """Load the full image dataset and split it into train/val/test.

        Reads paths and loader settings from ``self.args``, builds an
        ``ImageFolder`` dataset, records the discovered class names on
        ``self.classes``, and stores the three resulting ``DataLoader``s on
        ``self.train_loader`` / ``self.val_loader`` / ``self.test_loader``.
        """
        cfg = self.args

        dataset = ImageFolder(root=cfg.all_data_path, transform=TRANSFORM)
        self.classes = dataset.classes

        # Carve out equally-sized validation and test splits; whatever
        # remains is used for training.
        n_val = int(cfg.validation_split_size * len(dataset))
        n_test = int(cfg.validation_split_size * len(dataset))
        n_train = len(dataset) - 2 * n_val
        split_train, split_val, split_test = random_split(
            dataset, [n_train, n_val, n_test])

        def make_loader(subset):
            # All three loaders share identical batching configuration.
            return DataLoader(subset,
                              batch_size=cfg.batch_size,
                              num_workers=cfg.num_workers,
                              shuffle=cfg.shuffle)

        self.train_loader = make_loader(split_train)
        self.val_loader = make_loader(split_val)
        self.test_loader = make_loader(split_test)

    def train(self):
        """Train the network for ``args.epoch`` epochs.

        Each epoch runs one pass over the training loader, checkpoints the
        model, and logs train/validation loss and weighted F1 to TensorBoard.

        Raises:
            ValueError: if ``args.optim`` is not 'RMSprop', 'SGD' or 'Adam'.
        """

        if self.args.optim == 'RMSprop':
            optimizer = optim.RMSprop(self.net.parameters(),
                                      lr=self.args.lr,
                                      momentum=self.args.momentum,
                                      weight_decay=self.args.l2_regularization)
        elif self.args.optim == 'SGD':
            optimizer = optim.SGD(self.net.parameters(),
                                  lr=self.args.lr,
                                  momentum=self.args.momentum)
        elif self.args.optim == 'Adam':
            optimizer = optim.Adam(self.net.parameters(), lr=self.args.lr)
        else:
            # Fail fast: an unknown value previously fell through and crashed
            # later with UnboundLocalError on `optimizer`.
            raise ValueError('Unsupported optimizer: {}'.format(
                self.args.optim))

        criterion = nn.CrossEntropyLoss()

        self.net.train()

        for epoch in range(self.args.epoch):
            start_time = time.time()

            running_loss_t = 0.0
            num_batches = 0

            y_true = []
            y_pred = []

            for data in tqdm(self.train_loader):
                inputs, labels = data
                labels_cp = labels.clone()

                # Skip degenerate batches of size < 2 (presumably to keep
                # BatchNorm statistics well-defined — confirm).
                if len(inputs) < 2:
                    continue

                if self.args.cuda:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                outputs = self.net(inputs)

                loss = criterion(outputs, labels)

                # Collect per-sample predictions for the epoch F1 score.
                _, predicted = torch.max(outputs, 1)
                predicted = predicted.cpu()
                for i, pred in enumerate(predicted):
                    y_pred.append(pred)
                    y_true.append(labels_cp[i])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss_t += loss.item()
                num_batches += 1

            end_time = time.time()

            train_f1 = f1_score(y_true, y_pred, average='weighted')

            self.save(epoch + 1)
            self.writer.add_scalar('Loss/train', running_loss_t / num_batches,
                                   epoch + 1)
            self.writer.add_scalar('F1/train', train_f1, epoch + 1)

            loss_v, val_f1 = self.get_validation_loss(criterion)

            self.writer.add_scalar('Loss/val', loss_v, epoch + 1)
            self.writer.add_scalar('F1/val', val_f1, epoch + 1)

            print(
                "Epoch {} Time {:.2f}s Train-Loss {:.3f} Val-Loss {:.3f} Train-F1 {:.3f} Val-F1 {:.3f}"
                .format(epoch + 1, end_time - start_time,
                        running_loss_t / num_batches, loss_v, train_f1,
                        val_f1))

    def get_validation_loss(self, criterion):
        """Compute average loss and weighted F1 over the validation set.

        Temporarily switches the network to eval mode (train mode is
        restored before returning) and runs without gradient tracking.

        Returns:
            (mean_loss, weighted_f1) over ``self.val_loader``.
        """
        total_loss = 0.0
        batch_count = 0

        self.net.eval()
        all_labels = []
        all_preds = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader):
                cpu_labels = labels.clone()

                if self.args.cuda:
                    images = images.cuda()
                    labels = labels.cuda()

                outputs = self.net(images)

                # Collect per-sample predictions for the F1 computation.
                batch_preds = torch.max(outputs, 1)[1].cpu()
                for idx, pred in enumerate(batch_preds):
                    all_preds.append(pred)
                    all_labels.append(cpu_labels[idx])

                total_loss += criterion(outputs, labels).item()
                batch_count += 1

        # Restore training mode for the caller's training loop.
        self.net.train()
        weighted_f1 = f1_score(all_labels, all_preds, average='weighted')

        return total_loss / batch_count, weighted_f1

    def get_test_accuracy(self):
        """Evaluate the model on the test split and print its weighted F1.

        Also accumulates per-class correct/total counts (currently only
        collected, never reported).

        NOTE(review): the net is not switched to eval mode here — callers
        should call ``self.net.eval()`` first if the model uses
        dropout/batchnorm.
        """
        y_true = []
        y_pred = []
        class_correct = list(0. for i in range(4))
        class_total = list(0. for i in range(4))

        with torch.no_grad():
            for data in tqdm(self.test_loader):
                images, labels = data
                labels_cp = labels.clone()
                if self.args.cuda:
                    images = images.cuda()
                    labels = labels.cuda()
                outputs = self.net(images)
                _, predicted = torch.max(outputs, 1)
                predicted = predicted.cpu()
                for i, pred in enumerate(predicted):
                    y_pred.append(pred)
                    y_true.append(labels_cp[i])

                # Bug fix: the old `.squeeze()` collapsed a size-1 batch to
                # a 0-d tensor, making `c[i]` below raise IndexError. The
                # comparison is already 1-D, so no squeeze is needed.
                c = (predicted == labels_cp)

                for i in range(min(self.args.batch_size, len(labels_cp))):
                    label = labels_cp[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1

        print("Test F1: ", f1_score(y_true, y_pred, average='weighted'))

    def save(self, epochs, folder_path="../models/"):
        """
        Save Model
        """
        dict_save = {'params': self.net.state_dict(), 'classes': self.classes}
        name = folder_path + self.unique_tok + '_' + str(epochs) + '.model'
        torch.save(dict_save, name)
        print('Model saved at {}'.format(name))
        return name

    def load(self, path):
        """
        Load a saved model
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dict_load = torch.load(path, map_location=torch.device(device))
        self.net.load_state_dict(dict_load['params'])
        return dict_load['classes']

    def predict(self, inp):
        """
        Predict using net
        """

        if self.args.cuda:
            inp = inp.cuda()

        self.net.eval()
        with torch.no_grad():
            vals = self.net(inp)
            print(vals)
            _, predicted = torch.max(vals, 1)
            predicted = predicted.cpu()
            result_class = self.classes[predicted]

        return result_class