Пример #1
0
    def validation(discriminator, send_stats=False, epoch=0):
        """Run ``discriminator`` over the validation loader and score it.

        Relies on names from the enclosing scope: ``validation_dataset``,
        ``validation_loader``, ``adversarial_loss``, ``accuracy_metrics``
        and ``hparams``.

        Args:
            discriminator: module to evaluate.
            send_stats: accepted for interface compatibility; not used here.
            epoch: accepted for interface compatibility; not used here.

        Returns:
            A pair ``(accuracy_metrics(labels, pred_logits), val_loss)``
            computed over all validation batches concatenated along dim 0.
        """
        print('Validating model on {0} examples. '.format(
            len(validation_dataset)))

        # ``eval()`` switches the module in place, so remember the caller's
        # mode and restore it afterwards; otherwise training would silently
        # continue in inference mode after the first validation pass.
        was_training = discriminator.training
        discriminator.eval()

        try:
            with torch.no_grad():
                batch_logits = []
                batch_labels = []

                for (inp, labels, _names) in tqdm(validation_loader):
                    inp = inp.float()
                    labels = labels.long()

                    if hparams.dim3:
                        # Reshape to a single channel, then replicate it
                        # three times along the channel dimension.
                        inp = inp.view(-1, 1, 640, 64)
                        inp = torch.cat([inp] * 3, dim=1)

                    inp = inp.to(hparams.gpu_device)
                    labels = labels.to(hparams.gpu_device)

                    batch_logits.append(discriminator(inp))
                    batch_labels.append(labels)

                pred_logits = torch.cat(batch_logits, dim=0)
                labels = torch.cat(batch_labels, dim=0)

                val_loss = adversarial_loss(pred_logits, labels)
        finally:
            discriminator.train(was_training)

        return accuracy_metrics(labels.long(), pred_logits), val_loss
Пример #2
0
    def validation(discriminator, send_stats=False, epoch=0):
        """Run ``discriminator`` over the validation loader and score it.

        Relies on names from the enclosing scope: ``validation_dataset``,
        ``validation_loader``, ``adversarial_loss``, ``accuracy_metrics``
        and ``hparams``.

        Args:
            discriminator: module to evaluate.
            send_stats: accepted for interface compatibility; not used here.
            epoch: accepted for interface compatibility; not used here.

        Returns:
            A pair ``(accuracy_metrics(labels.long(), pred_logits), val_loss)``
            computed over all validation batches concatenated along dim 0.
            The loss is computed against the float labels.
        """
        print('Validating model on {0} examples. '.format(
            len(validation_dataset)))

        # ``eval()`` switches the module in place, so remember the caller's
        # mode and restore it afterwards; otherwise training would silently
        # continue in inference mode after the first validation pass.
        was_training = discriminator.training
        discriminator.eval()

        try:
            with torch.no_grad():
                batch_logits = []
                batch_labels = []

                for (img, labels, _names) in tqdm(validation_loader):
                    img_ = img.float().to(hparams.gpu_device)
                    labels = labels.float().to(hparams.gpu_device)

                    batch_logits.append(discriminator(img_))
                    batch_labels.append(labels)

                pred_logits = torch.cat(batch_logits, dim=0)
                labels = torch.cat(batch_labels, dim=0)

                val_loss = adversarial_loss(pred_logits, labels)
        finally:
            discriminator.train(was_training)

        return accuracy_metrics(labels.long(), pred_logits), val_loss
Пример #3
0
def test(model_paths,
         data=(hparams.valid_csv, hparams.valid_dir),
         plot_auc='valid',
         plot_path=hparams.result_dir + 'valid',
         best_thresh=None,
         pred_csv=None):
    """Evaluate an ensemble of discriminator checkpoints and report f1/acc.

    The logits of all checkpoints are averaged per batch, and the whole
    pass over the dataset is repeated ``hparams.repeat_infer`` times with
    the results averaged again. The predicted label is the arg-max of the
    averaged logits.

    Args:
        model_paths: sequence of checkpoint paths; each checkpoint must hold
            a ``'discriminator_state_dict'`` entry.
        data: ``(csv_path, data_dir)`` pair passed to ``ChestData``.
        plot_auc: unused in this function (kept for call compatibility).
        plot_path: unused in this function (kept for call compatibility).
        best_thresh: unused in this function (kept for call compatibility).
        pred_csv: unused in this function (kept for call compatibility).

    Returns:
        The f1 value returned by ``accuracy_metrics``.

    Raises:
        ValueError: if ``model_paths`` is empty.
    """
    if not model_paths:
        raise ValueError('model_paths must contain at least one checkpoint')

    test_dataset = ChestData(
        data_csv=data[0],
        data_dir=data[1],
        augment=hparams.TTA,
        transform=transforms.Compose([
            transforms.Resize(hparams.image_shape),
            transforms.ToTensor(),
            #                             transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
        ]))

    # shuffle=False keeps sample order identical across repeated passes so
    # the per-pass logits can be summed element-wise.
    test_loader = DataLoader(test_dataset,
                             batch_size=hparams.batch_size,
                             shuffle=False,
                             num_workers=4)

    discriminators = [
        Discriminator().to(hparams.gpu_device) for _ in model_paths
    ]
    if hparams.cuda:
        discriminators = [
            nn.DataParallel(model, device_ids=hparams.device_ids)
            for model in discriminators
        ]
    for model, model_path in zip(discriminators, model_paths):
        checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
        model.load_state_dict(checkpoint['discriminator_state_dict'])

    def put_eval(model):
        model = model.eval()
        # Overrides only the top-level module's flag; submodules keep the
        # eval state set by ``eval()`` above.
        model.training = hparams.eval_dp_on
        return model

    discriminators = [put_eval(model) for model in discriminators]

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits = torch.zeros(
            (len(test_dataset), hparams.num_classes)).to(hparams.gpu_device)
        for _ in range(hparams.repeat_infer):
            labels_list = []
            pred_logits_list = []
            for (img, labels, _img_names) in tqdm(test_loader):
                # Always move the batch to the models' device; previously
                # ``img_`` was only bound when ``hparams.cuda`` was set.
                img_ = img.float().to(hparams.gpu_device)
                labels = labels.float().to(hparams.gpu_device)

                # Mean of the ensemble members' logits for this batch.
                ensemble_logits = sum(
                    model(img_) for model in discriminators
                ) / len(discriminators)

                pred_logits_list.append(ensemble_logits)
                labels_list.append(labels)

            pred_logits += torch.cat(pred_logits_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

        pred_logits = 1.0 * pred_logits / hparams.repeat_infer
        _, pred_labels = torch.max(pred_logits, dim=1)

        f1, acc, conf_mat = accuracy_metrics(labels, pred_labels)

        print('== Test on -- ' + str(model_paths) +
              ' == \n f1 - {0:.4f}, acc - {1:.4f}'.format(f1, acc))
    return f1
Пример #4
0
def test(model_path,
         data=(hparams.valid_csv, hparams.dev_file),
         plot_auc='valid',
         plot_path=hparams.result_dir + 'valid',
         best_thresh=None):
    """Evaluate one discriminator checkpoint on an ``AudioData`` split.

    Saves the confusion-matrix plot to
    ``hparams.result_dir + 'test_conf_mat.png'`` and prints the average and
    per-class accuracies.

    Args:
        model_path: checkpoint path; the checkpoint must hold a
            ``'discriminator_state_dict'`` entry.
        data: ``(csv_path, data_file)`` pair passed to ``AudioData``.
        plot_auc: forwarded to ``accuracy_metrics``.
        plot_path: forwarded to ``accuracy_metrics``.
        best_thresh: forwarded to ``accuracy_metrics``.

    Returns:
        ``acc['avg']`` from the ``accuracy_metrics`` result.
    """
    test_dataset = AudioData(data_csv=data[0],
                             data_file=data[1],
                             ds_type='valid',
                             augment=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                             ]))

    test_loader = DataLoader(test_dataset,
                             batch_size=hparams.batch_size,
                             shuffle=True,
                             num_workers=2)

    discriminator = Discriminator().to(hparams.gpu_device)
    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator,
                                        device_ids=hparams.device_ids)
    checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    discriminator = discriminator.eval()

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits_list = []
        labels_list = []
        for (inp, labels, _img_names) in tqdm(test_loader):
            inp = inp.float().to(hparams.gpu_device)
            labels = labels.long().to(hparams.gpu_device)

            if hparams.dim3:
                # Reshape to a single channel, then replicate it three
                # times along the channel dimension.
                inp = inp.view(-1, 1, 640, 64)
                inp = torch.cat([inp] * 3, dim=1)

            pred_logits_list.append(discriminator(inp))
            labels_list.append(labels)

        pred_logits = torch.cat(pred_logits_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        auc, f1, acc, conf_mat = accuracy_metrics(labels,
                                                  pred_logits,
                                                  plot_auc=plot_auc,
                                                  plot_path=plot_path,
                                                  best_thresh=best_thresh)

        fig = plot_cf(conf_mat)
        plt.savefig(hparams.result_dir + 'test_conf_mat.png')
        # Release the figure once it is on disk; pyplot otherwise keeps
        # every figure alive across repeated calls.
        plt.close(fig)

        # NOTE(review): the class count is hard-coded to 10 here; confirm it
        # matches the configured number of classes.
        per_class = ''.join(
            ', acc_{}'.format(hparams.id_to_class[it]) +
            ' - {0:.4f}'.format(acc[it]) for it in range(10))
        res = ' -- avg_acc - {0:.4f}'.format(acc['avg']) + per_class
        print('== Test on -- ' + model_path + res)
    return acc['avg']
Пример #5
0
def train(resume=False):
    """Train the discriminator classifier and checkpoint the best epoch.

    Each epoch runs one pass over the training loader, then validates,
    logs per-class and averaged metrics to the ``SummaryWriter``, steps the
    ``ReduceLROnPlateau`` scheduler on the validation loss, and saves a
    checkpoint to ``hparams.model + '.best'`` whenever the average
    validation accuracy is at least as good as the best seen so far.

    Args:
        resume: accepted for interface compatibility; not used here.
    """
    writer = SummaryWriter('../runs/' + hparams.exp_name)

    # Store every hyper-parameter with the run.
    for key, value in hparams.__dict__.items():
        writer.add_text(str(key), str(value))

    train_dataset = AudioData(
        data_csv=hparams.train_csv,
        data_file=hparams.dev_file,
        ds_type='train',  # augment=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]))

    validation_dataset = AudioData(data_csv=hparams.valid_csv,
                                   data_file=hparams.dev_file,
                                   ds_type='valid',
                                   augment=False,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                   ]))

    train_loader = DataLoader(train_dataset,
                              batch_size=hparams.batch_size,
                              shuffle=True,
                              num_workers=2)

    validation_loader = DataLoader(validation_dataset,
                                   batch_size=hparams.batch_size,
                                   shuffle=True,
                                   num_workers=2)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    adversarial_loss = torch.nn.CrossEntropyLoss().to(hparams.gpu_device)
    discriminator = Discriminator().to(hparams.gpu_device)

    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator,
                                        device_ids=hparams.device_ids)

    params_count = sum(param.numel() for param in discriminator.parameters())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.pretrained:
        discriminator.apply(weights_init_normal)

    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=hparams.learning_rate)

    scheduler_D = ReduceLROnPlateau(optimizer_D,
                                    mode='min',
                                    factor=0.3,
                                    patience=4,
                                    verbose=True,
                                    cooldown=0)

    def validation(discriminator, send_stats=False, epoch=0):
        """Score ``discriminator`` on the validation loader.

        Leaves the module in eval mode; the epoch loop below switches it
        back to train mode before the next training pass.

        Returns:
            ``(accuracy_metrics(labels, pred_logits), val_loss)`` over all
            validation batches concatenated along dim 0.
        """
        print('Validating model on {0} examples. '.format(
            len(validation_dataset)))
        discriminator_ = discriminator.eval()

        with torch.no_grad():
            pred_logits_list = []
            labels_list = []

            for (inp, labels, _names) in tqdm(validation_loader):
                inp = inp.float()
                labels = labels.long()

                if hparams.dim3:
                    inp = inp.view(-1, 1, 640, 64)
                    inp = torch.cat([inp] * 3, dim=1)

                inp = inp.to(hparams.gpu_device)
                labels = labels.to(hparams.gpu_device)

                pred_logits_list.append(discriminator_(inp))
                labels_list.append(labels)

            pred_logits = torch.cat(pred_logits_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            val_loss = adversarial_loss(pred_logits, labels)

        return accuracy_metrics(labels.long(), pred_logits), val_loss

    print('Starting training.. (log saved in:{})'.format(hparams.exp_name))
    start_time = time.time()
    best_valid_acc = 0

    for epoch in range(hparams.num_epochs):
        # ``validation`` puts the model in eval mode; switch back so every
        # epoch (not just the first) trains in train mode.
        discriminator.train()

        train_logits = []
        train_labels = []
        for batch, (inp, labels, _names) in enumerate(tqdm(train_loader)):

            inp = inp.float().to(hparams.gpu_device)
            labels = labels.long().to(hparams.gpu_device)

            if hparams.dim3:
                # Reshape to a single channel, then replicate it three
                # times along the channel dimension.
                inp = inp.view(-1, 1, 640, 64)
                inp = torch.cat([inp] * 3, dim=1)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            pred_logits = discriminator(inp)
            # Keep only the values for the epoch-level metrics; storing the
            # graph-attached tensor would retain every batch's autograd
            # graph until the end of the epoch.
            train_logits.append(pred_logits.detach())
            train_labels.append(labels)

            d_loss = adversarial_loss(pred_logits, labels)

            d_loss.backward()
            optimizer_D.step()

            writer.add_scalar('d_loss',
                              d_loss.item(),
                              global_step=batch + epoch * len(train_loader))

        (val_auc, val_f1, val_acc,
         val_conf_mat), val_loss = validation(discriminator, epoch=epoch)

        train_logits = torch.cat(train_logits, dim=0)
        train_labels = torch.cat(train_labels, dim=0)

        _, _, train_acc, _ = accuracy_metrics(train_labels.long(),
                                              train_logits)

        fig = plot_cf(val_conf_mat)
        writer.add_figure('val_conf', fig, global_step=epoch)
        plt.close(fig)

        for lbl in range(hparams.num_classes):
            class_name = hparams.id_to_class[lbl]
            writer.add_scalar('val_f1_{}'.format(class_name),
                              val_f1[lbl],
                              global_step=epoch)
            writer.add_scalar('val_auc_{}'.format(class_name),
                              val_auc[lbl],
                              global_step=epoch)
            writer.add_scalar('val_acc_{}'.format(class_name),
                              val_acc[lbl],
                              global_step=epoch)
        for avg in ('micro', 'macro'):
            writer.add_scalar('val_f1_{}'.format(avg),
                              val_f1[avg],
                              global_step=epoch)
            writer.add_scalar('val_auc_{}'.format(avg),
                              val_auc[avg],
                              global_step=epoch)
        writer.add_scalar('val_loss', val_loss, global_step=epoch)
        writer.add_scalar('val_f1',
                          val_f1[hparams.avg_mode],
                          global_step=epoch)
        writer.add_scalar('val_auc',
                          val_auc[hparams.avg_mode],
                          global_step=epoch)
        writer.add_scalar('val_acc', val_acc['avg'], global_step=epoch)
        scheduler_D.step(val_loss)
        writer.add_scalar('learning_rate',
                          optimizer_D.param_groups[0]['lr'],
                          global_step=epoch)

        if best_valid_acc <= val_acc['avg']:
            best_valid_acc = val_acc['avg']
            fig = plot_cf(val_conf_mat)
            writer.add_figure('best_val_conf', fig, global_step=epoch)
            plt.close(fig)
            torch.save(
                {
                    'epoch': epoch,
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizer_D_state_dict': optimizer_D.state_dict(),
                }, hparams.model + '.best')
            print('best model on validation set saved.')

        print('[Epoch - {0:.1f} ---> train_acc - {1:.4f}, current_lr - {2:.6f}, val_loss - {3:.4f}, best_val_acc - {4:.4f}, val_acc - {5:.4f}, val_f1 - {6:.4f}] - time - {7:.1f}'
              .format(1.0 * epoch, train_acc['avg'],
                      optimizer_D.param_groups[0]['lr'], val_loss,
                      best_valid_acc, val_acc['avg'],
                      val_f1[hparams.avg_mode], time.time() - start_time))
        # Restart the clock so the reported time is per epoch.
        start_time = time.time()

    # Flush pending events and release the log file handle.
    writer.close()
Пример #6
0
def test(model_paths, data=(hparams.valid_csv, hparams.valid_dir), plot_auc='valid', plot_path=hparams.result_dir+'valid', best_thresh=None, pred_csv=None):
    """Evaluate an ensemble of discriminator checkpoints and report AUC/acc/f1.

    The logits of all checkpoints are averaged per batch, and the whole
    pass over the dataset is repeated ``hparams.repeat_infer`` times with
    the results averaged again. Predictions are binarised with the
    thresholds returned by ``accuracy_metrics`` and optionally written to
    ``../results/predictions_<pred_csv>.csv``.

    Args:
        model_paths: sequence of checkpoint paths; each checkpoint must hold
            a ``'discriminator_state_dict'`` entry.
        data: ``(csv_path, data_dir)`` pair passed to ``ChestData``.
        plot_auc: forwarded to ``accuracy_metrics``.
        plot_path: forwarded to ``accuracy_metrics`` (also printed).
        best_thresh: forwarded to ``accuracy_metrics``; replaced by the
            thresholds that call returns.
        pred_csv: if truthy, suffix of the predictions CSV to write.

    Returns:
        ``auc['micro']`` from the ``accuracy_metrics`` result.

    Raises:
        ValueError: if ``model_paths`` is empty.
    """
    if not model_paths:
        raise ValueError('model_paths must contain at least one checkpoint')

    test_dataset = ChestData(data_csv=data[0], data_dir=data[1], augment=hparams.TTA,
                        transform=transforms.Compose([
                            transforms.Resize(hparams.image_shape),
                            transforms.ToTensor(),
#                             transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
                        ]))

    # shuffle=False keeps sample order identical across repeated passes so
    # the per-pass logits can be summed element-wise.
    test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size,
                            shuffle=False, num_workers=4)

    discriminators = [Discriminator().to(hparams.gpu_device) for _ in model_paths]
    if hparams.cuda:
        discriminators = [nn.DataParallel(model, device_ids=hparams.device_ids) for model in discriminators]
    for model, model_path in zip(discriminators, model_paths):
        checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
        model.load_state_dict(checkpoint['discriminator_state_dict'])

    discriminators = [model.eval() for model in discriminators]

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits = torch.zeros((len(test_dataset), hparams.num_classes)).to(hparams.gpu_device)
        for _ in range(hparams.repeat_infer):
            labels_list = []
            img_names_list = []
            pred_logits_list = []
            for (img, labels, img_names) in tqdm(test_loader):
                # Always move the batch to the models' device; previously
                # ``img_`` was only bound when ``hparams.cuda`` was set.
                img_ = img.float().to(hparams.gpu_device)
                labels = labels.float().to(hparams.gpu_device)

                # Mean of the ensemble members' logits for this batch.
                ensemble_logits = sum(model(img_) for model in discriminators) / len(discriminators)

                pred_logits_list.append(ensemble_logits)
                labels_list.append(labels)
                img_names_list += list(img_names)

            pred_logits += torch.cat(pred_logits_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

        pred_logits = 1.0*pred_logits/hparams.repeat_infer
        print(plot_path)

        auc, f1, acc, conf_mat, best_thresh = accuracy_metrics(labels, pred_logits, plot_auc=plot_auc, plot_path=plot_path, best_thresh=best_thresh)
        pred_logits = pred_logits.cpu().numpy()
        print(best_thresh)

        pred_labels = 1*(pred_logits > np.array(best_thresh))
        if pred_csv:
            # One column per class; ``pred_labels`` has exactly
            # ``hparams.num_classes`` columns by construction above.
            predictions = {'Path': img_names_list}
            for lbl in range(hparams.num_classes):
                predictions[hparams.id_to_class[lbl]] = pred_labels[:,lbl]
            df = pd.DataFrame(predictions)
            df.to_csv('../results/predictions_{}.csv'.format(pred_csv), index=False)
            print('predictions saved to "../results/predictions_{}.csv"'.format(pred_csv))

        # NOTE(review): the summary below reports only classes 0-4.
        print('== Test on -- '+str(model_paths)+' == \n\
            auc_{0} - {5:.4f}, auc_{1} - {6:.4f}, auc_{2} - {7:.4f}, auc_{3} - {8:.4f}, auc_{4} - {9:.4f}, auc_micro - {10:.4f}, auc_macro - {11:.4f},\n\
            acc_{0} - {12:.4f}, acc_{1} - {13:.4f}, acc_{2} - {14:.4f}, acc_{3} - {15:.4f}, acc_{4} - {16:.4f}, acc_avg - {17:.4f},\n\
            f1_{0} - {18:.4f}, f1_{1} - {19:.4f}, f1_{2} - {20:.4f}, f1_{3} - {21:.4f}, f1_{4} - {22:.4f}, f1_micro - {23:.4f}, f1_macro - {24:.4f},\n\
            thresh_{0} - {25:.4f}, thresh_{1} - {26:.4f}, thresh_{2} - {27:.4f}, thresh_{3} - {28:.4f}, thresh_{4} - {29:.4f} =='.\
            format(hparams.id_to_class[0], hparams.id_to_class[1], hparams.id_to_class[2], hparams.id_to_class[3], hparams.id_to_class[4], auc[0], auc[1], auc[2], auc[3], auc[4], auc['micro'], auc['macro'], acc[0], acc[1], acc[2], acc[3], acc[4], acc['avg'],
            f1[0], f1[1], f1[2], f1[3], f1[4], f1['micro'], f1['macro'], best_thresh[0], best_thresh[1], best_thresh[2], best_thresh[3], best_thresh[4]))
    return auc['micro']
Пример #7
0
def test(model_path=hparams.model, send_stats=False):
    """Run the ``Rater`` checkpoint over two test sets and write a submission.

    Predictions for both loaders are written to ``submission.csv`` (columns
    ``id_code``, ``diagnosis``) and the pooled predictions are scored with
    ``accuracy_metrics``.

    Args:
        model_path: checkpoint path; the checkpoint must hold a
            ``'model_state_dict'`` entry.
        send_stats: when true, ``accuracy_metrics`` is called in its
            detailed (nine-value) mode instead of the five-value mode.

    Returns:
        The kappa value returned by ``accuracy_metrics``.
    """
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    test_dataset = APTOSData(csv_file=hparams.train_csv,
                        root_dir=hparams.train_dir,
                        split=0.95,
                        test_split=0.95,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            normalize,
                        ]),
                        ds_type='test', file_format='.png')

    extra_test_dataset = APTOSData(csv_file=hparams.dg_test_csv,
                        root_dir=hparams.dg_test_dir,
                        split=0,
                        test_split=0.0,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            normalize,
                        ]),
                        ds_type='test', file_format='.jpg')

    test_loader = DataLoader(test_dataset, batch_size=64,
                            shuffle=False, num_workers=2)

    extra_test_loader = DataLoader(extra_test_dataset, batch_size=64,
                            shuffle=False, num_workers=2)

    test_loaders = [test_loader, extra_test_loader]

    model = Rater(hparams.image_shape, hparams.num_classes, pretrained=False)
    if hparams.cuda:
        model = model.cuda(hparams.gpu_device)
        checkpoint = torch.load(model_path)
    else:
        checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # A freshly constructed module is in train mode; switch to eval so
    # dropout / batch-norm layers behave deterministically at inference.
    model.eval()

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    rounder = OptimizedRounder()

    # Report samples (not batches) across both test sets.
    print('Testing model on {0} examples. '.format(
        len(test_dataset) + len(extra_test_dataset)))
    # newline='' is what the csv module requires of files it writes to;
    # without it, extra blank rows appear on platforms that translate '\n'.
    with open('submission.csv', 'w', newline='') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['id_code', 'diagnosis'])
        with torch.no_grad():
            pred_labels_list = []
            labels_list = []
            img_names_list = []
            for loader in tqdm(test_loaders):
                for (img, labels, img_names) in loader:
                    img = img.type(Tensor)
                    labels = labels.float()
                    if hparams.cuda:
                        img = img.cuda(hparams.gpu_device)
                        labels = labels.cuda(hparams.gpu_device)
                    pred_logits, _ = model(img)
                    pred_labels = torch.tensor(rounder.predict(pred_logits.view(-1), hparams.coefficients))
                    img_names_list += list(img_names)
                    pred_labels_list.append(pred_labels.view(-1))
                    labels_list.append(labels.view(-1))

            pred_labels = torch.cat(pred_labels_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            for i, img_name in enumerate(img_names_list):
                csv_writer.writerow([img_name, pred_labels[i].item()])

    # Compute the metrics once; previously the detailed call's result was
    # immediately overwritten by a second, redundant call.
    if send_stats:
        kappa, precision, recall, f1, accuracy, precision_list, recall_list, f1_list, accuracy_list = accuracy_metrics(pred_labels, labels, True)
    else:
        kappa, precision, recall, f1, accuracy = accuracy_metrics(pred_labels, labels, False)

    print('== Test on -- '+model_path+' == kappa - {0:.4f}, precision - {1:.4f}, recall - {2:.4f}, f1 - {3:.4f}, accuracy - {4:.4f} =='.format(kappa, precision, recall, f1, accuracy))

    return kappa