Code Example #1
File: test.py Project: kejndan/PruningResNet20
import numpy as np
import torch


def evaluate(data_loader,
             model,
             epoch,
             cfg,
             mode='train',
             with_logs=True,
             with_print=True):
    if with_print:
        print(f'Evaluate on {mode} data')
    model.eval()
    nb_classes = data_loader.dataset.nb_classes
    conf_matrix = np.zeros((nb_classes, nb_classes))
    accuracy_sum = 0

    total_iter = len(data_loader)

    for batch_idx, batch in enumerate(data_loader):  # avoid shadowing builtin iter
        images = batch[0].to(cfg.device)
        labels = batch[1].type(torch.LongTensor).to(cfg.device)

        with torch.no_grad():
            output = model(images)
        _, predict = torch.max(output, 1)
        y_pred, y_true = predict.cpu().numpy(), labels.cpu().numpy()
        accuracy_sum += np.sum(y_pred == y_true)
        if cfg.plot_confusion_matrix:
            for i in range(len(images)):
                conf_matrix[y_true[i], y_pred[i]] += 1

        if batch_idx % 50 == 0 and with_print:
            print(f'Epoch: {epoch}. Batch {batch_idx} of {total_iter}.')

    accuracy = accuracy_sum / len(data_loader.dataset)
    if with_print:
        print(f'Epoch: {epoch}. {mode} accuracy {accuracy}')
    if with_logs:
        log_metric(f'eval/accuracy_{mode}', epoch, accuracy, cfg)

    if cfg.logger is not None:
        class_names = data_loader.dataset.name_classes
        cfg.logger.report_scalar('eval', f'accuracy_{mode}', accuracy, epoch)
        if cfg.plot_confusion_matrix:
            cfg.logger.report_matrix('ConfusionMatrix',
                                     f'Epoch: {epoch}',
                                     conf_matrix,
                                     epoch,
                                     xlabels=class_names,
                                     ylabels=class_names)
    return accuracy
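
For reference, a minimal usage sketch for this evaluate function, mirroring how Code Example #3 below calls it; CIFAR10 is the project's own dataset wrapper (not torchvision's) and cfg is the project's config object:

# usage sketch based on Code Example #3; names come from that excerpt
ds_test = CIFAR10(cfg.path_to_dataset, cfg, work_mode='test',
                  transform_mode='test')
dl_test = DataLoader(ds_test, batch_size=cfg.batch_size)
test_accuracy = evaluate(dl_test, model, epoch=0, cfg=cfg, mode='test',
                         with_logs=False)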
Code Example #2
import torch


def training_epoch(data_loader, model, criterion, optimizer, epoch, cfg):
    print('Train')
    model.train()
    total_iter = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):  # avoid shadowing builtin iter
        images = batch[0].to(cfg.device)
        labels = batch[1].type(torch.LongTensor).to(cfg.device)

        output = model(images)
        # criterion is called with reduction='mean', so it is expected to be
        # a functional loss such as F.cross_entropy, not an nn.Module instance
        loss = criterion(output, labels, reduction='mean')
        log_metric('loss/train_loss_finetune', total_iter * epoch + batch_idx,
                   loss.item(), cfg)
        if cfg.logger is not None:
            cfg.logger.report_scalar('loss', 'train_loss_finetune',
                                     loss.item(),
                                     total_iter * epoch + batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print(f'Epoch: {epoch}. Iteration {batch_idx} of {total_iter}.')
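
Because training_epoch passes reduction='mean' at call time, criterion must be a functional loss rather than a loss-module instance. A minimal sketch of a compatible setup (the optimizer settings are illustrative, not taken from the project):

import torch
import torch.nn.functional as F

# F.cross_entropy accepts reduction= per call, matching the call above;
# an nn.CrossEntropyLoss() instance would raise a TypeError instead
criterion = F.cross_entropy
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
training_epoch(dl_train, model, criterion, optimizer, epoch=0, cfg=cfg)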
Code Example #3
import os

import torch
from torch.utils.data import DataLoader

# CIFAR10, load_model, log_metric, evaluate and training_epoch are
# project-level helpers from kejndan/PruningResNet20


def train(model, criterion, optimizer, cfg, manual_load=None):
    ds_train = CIFAR10(cfg.path_to_dataset, cfg, transform_mode='train')
    dl_train = DataLoader(ds_train, batch_size=cfg.batch_size, shuffle=True)

    if cfg.evaluate_on_validation_data:
        ds_valid = CIFAR10(cfg.path_to_dataset,
                           cfg,
                           work_mode='valid',
                           transform_mode='test')
        dl_valid = DataLoader(ds_valid,
                              batch_size=cfg.batch_size,
                              shuffle=True)
    else:
        ds_test = CIFAR10(cfg.path_to_dataset,
                          cfg,
                          work_mode='test',
                          transform_mode='test')
        dl_test = DataLoader(ds_test, batch_size=cfg.batch_size, shuffle=True)

    if cfg.load_save:
        model, optimizer, start_epoch, max_accuracy = load_model(
            os.path.join(cfg.path_to_saves, cfg.name_save), model, cfg,
            optimizer)

    elif manual_load is not None:
        optimizer = manual_load['optimizer']
        start_epoch = manual_load['epoch']
        max_accuracy = manual_load['max_accuracy']

    else:
        start_epoch = 0
        max_accuracy = 0
    if cfg.use_lr_scheduler:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            min_lr=1e-8,
            patience=cfg.ROP_patience,
            factor=cfg.ROP_factor)
        # take an empty optimizer step so that optimizer.step() has been
        # called once before the first lr_scheduler.step()
        optimizer.zero_grad()
        optimizer.step()

    for epoch in range(start_epoch, cfg.total_epochs):
        training_epoch(dl_train, model, criterion, optimizer, epoch, cfg)
        if cfg.evaluate_on_train_data:
            evaluate(dl_train, model, epoch, cfg, mode='train')

        if cfg.evaluate_on_validation_data:
            accuracy = evaluate(dl_valid, model, epoch, cfg, mode='validation')
        else:
            accuracy = evaluate(dl_test, model, epoch, cfg, mode='test')
        if cfg.use_lr_scheduler:
            lr_scheduler.step(accuracy)

            log_metric('loss/learning_rate', epoch,
                       optimizer.param_groups[0]['lr'], cfg)
            if cfg.logger is not None:
                cfg.logger.report_scalar('learning_rate', 'all',
                                         optimizer.param_groups[0]['lr'],
                                         epoch)

        if cfg.save_models or cfg.save_best_model:
            state = {
                'model': model.state_dict(),
                'epoch': epoch,
                'best_accuracy': max_accuracy,
                'opt': optimizer.state_dict()
            }
            if cfg.save_models:
                torch.save(
                    state,
                    os.path.join(cfg.path_to_saves, f'checkpoint_{epoch}'))

            if cfg.save_best_model and max_accuracy < accuracy:
                torch.save(state,
                           os.path.join(cfg.path_to_saves, 'best_checkpoint'))
                max_accuracy = accuracy

            # keep the save directory small: drop the checkpoint from
            # three epochs ago, if it exists
            if os.path.exists(
                    os.path.join(cfg.path_to_saves,
                                 f'checkpoint_{epoch - 3}')):
                os.remove(
                    os.path.join(cfg.path_to_saves, f'checkpoint_{epoch - 3}'))
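
train() and its helpers read a number of configuration fields. A hypothetical sketch that collects every cfg attribute accessed in Code Examples #1-#3, with placeholder values rather than the project's real defaults:

from types import SimpleNamespace

import torch

# every field below is read somewhere in Code Examples #1-#3;
# the values are placeholders
cfg = SimpleNamespace(
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    path_to_dataset='./data',
    path_to_saves='./saves',
    name_save='checkpoint_0',
    batch_size=128,
    total_epochs=200,
    load_save=False,
    evaluate_on_train_data=False,
    evaluate_on_validation_data=True,
    use_lr_scheduler=True,
    ROP_patience=10,
    ROP_factor=0.5,
    save_models=True,
    save_best_model=True,
    plot_confusion_matrix=False,
    logger=None,
)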
Code Example #4
import json
import os

import pandas as pd

# args, output_folder_path, data_folder_path and cv_currentfold are
# defined earlier in the original script (argument parsing and setup)
os.makedirs(output_folder_path, exist_ok=True)
os.makedirs(f'{output_folder_path}/models', exist_ok=True)
os.makedirs(f'{output_folder_path}/figures', exist_ok=True)
os.makedirs(f'{output_folder_path}/tb_logs', exist_ok=True)
os.makedirs(f'{output_folder_path}/submissions', exist_ok=True)

with open(args.config_file) as config_file:
    config = json.load(config_file)

cv_folds = int(config['crossvalidation']['folds']) if 'crossvalidation' in config else None
version = config['version']
basic_name = config['name'] % (version)
if cv_currentfold is not None:
    basic_name = f'{basic_name}-cvfold{cv_currentfold}'

log_metric('Model Basic Name', basic_name)
log_metric('CV Current Fold', cv_currentfold)
log_metric('CV Total Folds', cv_folds)


img_size_ori = 101
img_size_target = int(config['input']['size'])
img_channels = int(config['input']['channels'])
img_normalize = bool(config['input']['normalize'])
img_normalize_divide = 255 if img_normalize else 1
img_color_mode = 'rgb' if img_channels == 3 else 'grayscale'

def upsample(img):
    return resize_image(img, img_size_target)


def downsample(img):
    return resize_image(img, img_size_ori)

train_df = pd.read_csv(f'{data_folder_path}/train.csv', index_col='id', usecols=[0])
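
resize_image is not defined in this excerpt; upsample and downsample simply bind it to the target and original sizes. A hypothetical implementation, assuming skimage-style resizing (the original project may do this differently):

from skimage.transform import resize

def resize_image(img, size):
    # hypothetical helper: resize a square image or mask to (size, size),
    # preserving the original value range
    if img.shape[0] == size:
        return img
    return resize(img, (size, size), mode='constant', preserve_range=True)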
Code Example #5
            def compare(a_result_dict,
                        b_result_dict,
                        a_name,
                        b_name,
                        postfix=""):
                # compare two test sets ("a" and "b") by logging per-sample
                # statistics and the ROC-AUC separability of several scores
                # (KL, mean, reconstruction loss, NLL, bits-per-dim)
                kl_a = kldiv(np.array(a_result_dict['test_a_mean']),
                             np.array(a_result_dict['test_a_log_var']))
                kl_b = kldiv(np.array(b_result_dict['test_b_mean']),
                             np.array(b_result_dict['test_b_log_var']))
                mean_a = np.mean(a_result_dict['test_a_mean'], axis=1)
                mean_b = np.mean(b_result_dict['test_b_mean'], axis=1)
                rec_a = a_result_dict['test_a_reconstloss']
                rec_b = b_result_dict['test_b_reconstloss']
                l2_mean_a = np.linalg.norm(a_result_dict['test_a_mean'],
                                           axis=1)
                l2_mean_b = np.linalg.norm(b_result_dict['test_b_mean'],
                                           axis=1)
                l2_var_a = np.linalg.norm(a_result_dict['test_a_log_var'],
                                          axis=1)
                l2_var_b = np.linalg.norm(b_result_dict['test_b_log_var'],
                                          axis=1)
                nll_a = kl_a + rec_a
                nll_b = kl_b + rec_b
                nllwrl_a = np.float32(args.reg_lambda) * kl_a + rec_a
                nllwrl_b = np.float32(args.reg_lambda) * kl_b + rec_b

                original_dim = np.float32(np.prod(args.original_shape))
                bpd_a = nll_a / original_dim
                bpd_b = nll_b / original_dim
                bpdwrl_a = nllwrl_a / original_dim
                bpdwrl_b = nllwrl_b / original_dim

                normed_nll_a = kl_a + (rec_a / original_dim)
                normed_nll_b = kl_b + (rec_b / original_dim)

                log_metric('test_mean_a{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(mean_a))
                log_metric('test_mean_b{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(mean_b))
                log_metric('test_var_a{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(np.exp(a_result_dict['test_a_log_var']),
                                     axis=(0, 1)))
                log_metric('test_var_b{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(np.exp(b_result_dict['test_b_log_var']),
                                     axis=(0, 1)))
                log_metric('test_rec_a{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(rec_a))
                log_metric('test_rec_b{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(rec_b))
                log_metric('test_kl_a{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(kl_a))
                log_metric('test_kl_b{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(kl_b))

                auc_kl = roc_auc_score(
                    np.concatenate([np.zeros_like(kl_a),
                                    np.ones_like(kl_b)]),
                    np.concatenate([kl_a, kl_b]))
                auc_mean = roc_auc_score(
                    np.concatenate(
                        [np.zeros_like(mean_a),
                         np.ones_like(mean_b)]),
                    np.concatenate([mean_a, mean_b]))
                auc_rec = roc_auc_score(
                    np.concatenate([np.zeros_like(rec_a),
                                    np.ones_like(rec_b)]),
                    np.concatenate([rec_a, rec_b]))
                auc_l2_mean = roc_auc_score(
                    np.concatenate(
                        [np.zeros_like(l2_mean_a),
                         np.ones_like(l2_mean_b)]),
                    np.concatenate([l2_mean_a, l2_mean_b]))
                auc_l2_var = roc_auc_score(
                    np.concatenate(
                        [np.zeros_like(l2_var_a),
                         np.ones_like(l2_var_b)]),
                    np.concatenate([l2_var_a, l2_var_b]))
                auc_nll = roc_auc_score(
                    np.concatenate([np.zeros_like(nll_a),
                                    np.ones_like(nll_b)]),
                    np.concatenate([nll_a, nll_b]))
                auc_normed_nll = roc_auc_score(
                    np.concatenate([
                        np.zeros_like(normed_nll_a),
                        np.ones_like(normed_nll_b)
                    ]), np.concatenate([normed_nll_a, normed_nll_b]))
                auc_nllwrl = roc_auc_score(
                    np.concatenate(
                        [np.zeros_like(nllwrl_a),
                         np.ones_like(nllwrl_b)]),
                    np.concatenate([nllwrl_a, nllwrl_b]))

                log_metric('auc_kl_{}_vs_{}{}'.format(args.test_dataset_a,
                                                      args.test_dataset_b,
                                                      postfix),
                           x=global_iters,
                           y=auc_kl)
                log_metric('auc_mean_{}_vs_{}{}'.format(
                    args.test_dataset_a, args.test_dataset_b, postfix),
                           x=global_iters,
                           y=auc_mean)
                log_metric('auc_rec_{}_vs_{}{}'.format(args.test_dataset_a,
                                                       args.test_dataset_b,
                                                       postfix),
                           x=global_iters,
                           y=auc_rec)
                log_metric('auc_l2_mean_{}_vs_{}{}'.format(
                    args.test_dataset_a, args.test_dataset_b, postfix),
                           x=global_iters,
                           y=auc_l2_mean)
                log_metric('auc_l2_var_{}_vs_{}{}'.format(
                    args.test_dataset_a, args.test_dataset_b, postfix),
                           x=global_iters,
                           y=auc_l2_var)
                log_metric('auc_neglog_likelihood_{}_vs_{}{}'.format(
                    args.test_dataset_a, args.test_dataset_b, postfix),
                           x=global_iters,
                           y=auc_nll)
                # bpd differs from nll only by a constant factor, so its
                # AUC is identical to auc_nll
                log_metric('auc_bpd{}'.format(postfix),
                           x=global_iters,
                           y=auc_nll)
                log_metric('auc_normed_nll{}'.format(postfix),
                           x=global_iters,
                           y=auc_normed_nll)
                log_metric('auc_nllwrl{}'.format(postfix),
                           x=global_iters,
                           y=auc_nllwrl)

                log_metric('test_bpd_a{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(bpd_a))
                log_metric('test_bpd_b{}'.format(postfix),
                           x=global_iters,
                           y=np.mean(bpd_b))

                if postfix == "":
                    log_metric('auc', x=global_iters, y=auc_nll)
                return kl_a, kl_b, rec_a, rec_b
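
kldiv is defined outside this excerpt. Since it is applied to per-sample means and log-variances and then added to a reconstruction loss to form an NLL, it presumably computes the standard VAE KL term against a unit Gaussian; a sketch under that assumption:

import numpy as np

def kldiv(mean, log_var):
    # assumed form: per-sample KL( N(mean, exp(log_var)) || N(0, I) ),
    # summed over the latent dimensions
    return 0.5 * np.sum(np.exp(log_var) + np.square(mean) - 1.0 - log_var,
                        axis=1)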
Code Example #6
                        # (excerpt opens mid-statement: these lines are the
                        # tail of the feed_dict for the training session.run)
                        encoder_input: x,
                        reconst_latent_input: z_x,
                        sampled_latent_input: z_p
                    })

        if global_iters % args.frequency == 0:
            summary, = session.run([summary_op], feed_dict={encoder_input: x})
            summary_writer.add_summary(summary, global_iters)

            if args.fixed_gen_as_negative:
                fetches = [gamma, encoder_loss, l_ae, l_reg_z, l_reg_zr_ng,
                           l_reg_zpp_ng, generator_loss, l_ae2, l_reg_zr,
                           l_reg_zpp, learning_rate, l_reg_zd,
                           discriminator_loss, l_reg_fixed_gen]
                (gamma_np, enc_loss_np, enc_l_ae_np, l_reg_z_np,
                 l_reg_zr_ng_np, l_reg_zpp_ng_np, generator_loss_np,
                 dec_l_ae_np, l_reg_zr_np, l_reg_zpp_np, lr_np, l_reg_zd_np,
                 disc_loss_np, l_reg_z_fixed_gen_np) = session.run(
                     fetches,
                     feed_dict={encoder_input: x,
                                reconst_latent_input: z_x,
                                sampled_latent_input: z_p,
                                fixed_gen_input: x_fg})
                log_metric('l_reg_fixed_gen',
                           x=global_iters,
                           y=l_reg_z_fixed_gen_np)
            else:
                fetches = [gamma, encoder_loss, l_ae, l_reg_z, l_reg_zr_ng,
                           l_reg_zpp_ng, generator_loss, l_ae2, l_reg_zr,
                           l_reg_zpp, learning_rate, l_reg_zd,
                           discriminator_loss]
                (gamma_np, enc_loss_np, enc_l_ae_np, l_reg_z_np,
                 l_reg_zr_ng_np, l_reg_zpp_ng_np, generator_loss_np,
                 dec_l_ae_np, l_reg_zr_np, l_reg_zpp_np, lr_np, l_reg_zd_np,
                 disc_loss_np) = session.run(fetches,
                                             feed_dict=train_feed_dict)

            log_metric('disc_loss', x=global_iters, y=disc_loss_np)
            log_metric('l_reg_zd', x=global_iters, y=l_reg_zd_np)
            log_metric('enc_loss', x=global_iters, y=enc_loss_np)
            log_metric('l_ae', x=global_iters, y=enc_l_ae_np)
            log_metric('l_reg_z', x=global_iters, y=l_reg_z_np)
            log_metric('generator_loss', x=global_iters, y=generator_loss_np)
            log_metric('dec_l_ae', x=global_iters, y=dec_l_ae_np)
            log_metric('l_reg_zr', x=global_iters, y=l_reg_zr_np)
            log_metric('l_reg_zpp', x=global_iters, y=l_reg_zpp_np)
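
log_metric is another project-level helper; Code Examples #5 and #6 call it with the keyword signature log_metric(name, x=..., y=...). A hypothetical stub for running these excerpts locally, where the real projects would route values to an experiment tracker:

def log_metric(name, x=None, y=None):
    # hypothetical stand-in: print step/value pairs instead of sending
    # them to a tracking backend
    print(f'{name}: step={x} value={y}')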