def train_and_test(flags,
                   corruption_level=0,
                   gold_fraction=0.5,
                   get_C=uniform_mix_C):
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    C = get_C(corruption_level)

    gold, silver = prepare_data(C, gold_fraction)

    print("Gold shape = {}, Silver shape = {}".format(gold.images.shape,
                                                      silver.images.shape))

    # TODO: evaluate on the full test set; only the first 500 examples are used for speed
    test_x = torch.from_numpy(mnist.test.images[:500].reshape([-1, 1, 28, 28]))
    test_y = torch.from_numpy(mnist.test.labels[:500]).type(torch.LongTensor)
    print("Test shape = {}".format(test_x.shape))

    model = LeNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for step in range(flags.num_steps):
        # silver = noisy training data, gold = small trusted set used for reweighting
        x, y = silver.next_batch(flags.batch_size)
        # silver labels come as (noisy, true) pairs; training uses the noisy ones,
        # the true ones are only kept for the diagnostics below
        y, y_true = np.array([l[0] for l in y]), np.array([l[1] for l in y])
        x_val, y_val = gold.next_batch(min(flags.batch_size, flags.nval))

        x = torch.from_numpy(x.reshape([-1, 1, 28, 28]))
        y = torch.from_numpy(y).type(torch.LongTensor)
        x_val = torch.from_numpy(x_val.reshape([-1, 1, 28, 28]))
        y_val = torch.from_numpy(y_val).type(torch.LongTensor)

        # forward
        if flags.method == "l2w":
            # learning to reweight: per-example weights from a one-step
            # look-ahead on the clean (gold) validation batch
            ex_wts = reweight_autodiff(model, x, y, x_val, y_val)
            logits, loss = model.loss(x, y, ex_wts)

            if step % dbg_steps == 0:
                tbrd.log_histogram("ex_wts", ex_wts, step=step)
                tbrd.log_value("More_than_0.01",
                               sum([x > 0.01 for x in ex_wts]),
                               step=step)
                tbrd.log_value("More_than_0.05",
                               sum([x > 0.05 for x in ex_wts]),
                               step=step)
                tbrd.log_value("More_than_0.1",
                               sum([x > 0.1 for x in ex_wts]),
                               step=step)

                # average weight given to correctly vs. incorrectly labelled examples
                mean_on_clean_labels = np.mean(
                    [ex_wts[i] for i in range(len(y)) if y[i] == y_true[i]])
                mean_on_dirty_labels = np.mean(
                    [ex_wts[i] for i in range(len(y)) if y[i] != y_true[i]])
                tbrd.log_value("mean_on_clean_labels",
                               mean_on_clean_labels,
                               step=step)
                tbrd.log_value("mean_on_dirty_labels",
                               mean_on_dirty_labels,
                               step=step)
        else:
            logits, loss = model.loss(x, y)

        print("Loss = {}".format(loss))

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tbrd.log_value("loss", loss, step=step)

        if step % dbg_steps == 0:
            model.eval()

            with torch.no_grad():
                pred = torch.max(model(test_x), 1)[1]
            test_acc = torch.sum(torch.eq(pred, test_y)).item() / float(
                test_y.shape[0])
            model.train()

            print("Test acc = {}.".format(test_acc))
            tbrd.log_value("test_acc", test_acc, step=step)
# class names are the sub-directories of the training set, in sorted order
dirs = sorted(os.listdir("./gcommands/train/"))

parser = argparse.ArgumentParser(
    description='ConvNets for Speech Commands Recognition')

parser.add_argument('--wav_path',
                    default='gcommands/recordings/one/3.wav',
                    help='path to the audio file')

args = parser.parse_args()

path = args.wav_path

warnings.filterwarnings("ignore")

model = LeNet()
# load the weights onto the CPU regardless of where the checkpoint was saved
model.load_state_dict(torch.load("checkpoint/ckpt.t7", map_location="cpu"))

model.eval()

wav = spect_loader(path,
                   window_size=.02,
                   window_stride=.01,
                   normalize=True,
                   max_len=101,
                   window='hamming')
#print(wav.shape)
with torch.no_grad():
    logits = model(wav.view(1, 1, 161, 101))
# map the predicted class index back to its folder name in the sorted class list
print(dirs[np.argmax(np.ravel(logits.numpy()))])
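

# spect_loader comes from the speech-commands code base used for training; the
# sketch below only illustrates the assumed preprocessing (log-magnitude STFT
# with a Hamming window, padded/clipped to max_len frames), written with
# librosa. The name spect_loader_sketch and the exact details are hypothetical.
def spect_loader_sketch(path, window_size, window_stride, window="hamming",
                        normalize=True, max_len=101):
    import librosa

    y, sr = librosa.load(path, sr=None)
    # for 16 kHz audio and window_size=0.02 s, n_fft=320 gives 161 frequency
    # bins, matching the 1 x 161 x 101 view used above
    n_fft = int(sr * window_size)
    hop_length = int(sr * window_stride)

    # complex STFT -> log(1 + magnitude) spectrogram
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                     win_length=n_fft, window=window)
    spect = np.log1p(np.abs(D))

    # pad or clip the time axis to a fixed number of frames
    if spect.shape[1] < max_len:
        spect = np.pad(spect, ((0, 0), (0, max_len - spect.shape[1])),
                       mode="constant")
    else:
        spect = spect[:, :max_len]

    spect = torch.FloatTensor(spect)
    if normalize:
        spect = (spect - spect.mean()) / (spect.std() + 1e-8)
    return spect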