コード例 #1
0
ファイル: main.py プロジェクト: SourceCode1037/SUGPool
def run(model):
    """Train *model* with early stopping on validation loss, then report test accuracy.

    Relies on the module-level ``args``, ``train_loader``, ``val_loader``,
    ``test_loader`` and ``optimizer``. After each epoch the model is scored on
    ``val_loader``; every improvement of the validation loss overwrites the
    checkpoint ``latest.pth``. Training stops once the loss has not improved for
    more than ``args.patience`` consecutive epochs. A fresh ``Net`` is then
    built, the best checkpoint is loaded into it and its accuracy on
    ``test_loader`` is printed.
    """
    best_val_loss = 1e10
    epochs_without_improvement = 0

    for epoch in range(args.epochs):
        print(f"Epoch{epoch}:")
        model.train()
        for batch in train_loader:
            batch = batch.to(args.device)
            loss = F.nll_loss(model(batch), batch.y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        val_acc, val_loss = test(model, val_loader)
        print(f"Validation loss:{val_loss}\taccuracy:{val_acc}")

        if not val_loss < best_val_loss:
            epochs_without_improvement += 1
        else:
            # New best model: checkpoint it and reset the patience counter.
            torch.save(model.state_dict(), 'latest.pth')
            print(f"Model saved at epoch{epoch}")
            best_val_loss = val_loss
            epochs_without_improvement = 0

        if epochs_without_improvement > args.patience:
            print(f"Early stop at epoch{epoch}")
            break

    # Evaluate the best checkpoint, not the weights from the last epoch run.
    best_model = Net(args).to(args.device)
    best_model.load_state_dict(torch.load('latest.pth'))
    test_acc, _ = test(best_model, test_loader)
    print(f"Test accuarcy:{test_acc}")
コード例 #2
0
def run_main():
    """Train a network on MNIST and save the recorded batch gradients to disk.

    Writes into ``gradients/test``:

    * ``weights.pth`` - the initial ``state_dict``, saved before any update;
    * ``epoch_<n>.pkl`` for n = 1..5 - the pickled object returned by
      ``record_grad`` for that epoch (300 steps). It is described as a dict
      mapping parameter names to per-batch arrays, e.g.
      ``'conv1.weight' -> array [batch x chan x h x w]``;
    * ``train_original.txt`` - one ``loss accuracy`` row per epoch, as returned
      by ``test``.
    """
    gradient_save_path = "gradients/test"
    os.makedirs(gradient_save_path, exist_ok=True)

    device = torch.device("cuda")
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=256,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=256,
        shuffle=True)

    model = Net(5).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # Save the initial weights so the gradients can later be replayed from the
    # same starting point.
    torch.save(model.state_dict(),
               os.path.join(gradient_save_path, "weights.pth"))
    training_curve = []
    for epoch in range(1, 6):
        gradients = record_grad(model,
                                device,
                                train_loader,
                                optimizer,
                                n_steps=300)
        # Save this epoch's gradients. The ``with`` block closes (and flushes)
        # the file; previously the handle passed to pickle.dump was never closed.
        epoch_path = os.path.join(gradient_save_path,
                                  "epoch_{}.pkl".format(epoch))
        with open(epoch_path, "wb") as grad_file:
            pickle.dump(gradients, grad_file)
        # Record test metrics for this epoch.
        loss, accuracy = test(model, device, test_loader)
        training_curve.append([loss, accuracy])
    np.savetxt(os.path.join(gradient_save_path, "train_original.txt"),
               np.array(training_curve))
コード例 #3
0
ファイル: trainer.py プロジェクト: new-stone-object/DNTD
    def __init__(self,
                 dataloader: DataLoader,
                 epoch: int,
                 optim_create_func: Optimizer,
                 lr: float,
                 loss_function: LOSS_FUNC,
                 device: torch.device = torch.device('cpu'),
                 pretrained_model_path: str = None,
                 visual_helper: VisualHelper = None,
                 model_saver: ModelSaver = None,
                 weight_init_func: WEIGHT_INIT_FUNC = None,
                 drop_rate=0,
                 *args,
                 **kwargs):
        """Create the ``Net`` to train and store the training configuration.

        The network is built on the CPU and put in training mode. It is then
        optionally initialised with ``weight_init_func`` (through
        ``Module.apply``) and given encoder weights from
        ``pretrained_model_path``. Only after that is it moved to ``device``.

        Args:
            dataloader: training data source, stored as ``self.dataloader``.
            epoch: stored as ``self.epoch``.
            optim_create_func: called once as ``optim_create_func(self.model, lr)``
                and stored as ``self.optim``.
            lr: learning rate passed to ``optim_create_func``.
            loss_function: stored as ``self.loss_func``.
            device: device the model is moved to at the end (default CPU).
            pretrained_model_path: if given, passed to
                ``Net.load_encoder_weight`` and the result is printed.
            visual_helper: stored unchanged.
            model_saver: stored unchanged.
            weight_init_func: if given, applied to the model via ``apply``.
            drop_rate: forwarded to ``Net``.
            *args, **kwargs: forwarded to the base-class ``__init__``.
        """
        super(Trainer, self).__init__(*args, **kwargs)

        self.model: Net = Net(drop_rate=drop_rate).to(torch.device('cpu'))
        self.model.train()
        self.dataloader = dataloader

        if weight_init_func is not None:
            self.model.apply(weight_init_func)
        if pretrained_model_path is not None:
            print(self.model.load_encoder_weight(pretrained_model_path))

        self.epoch = epoch
        # NOTE(review): the value returned by ``optim_create_func`` is discarded,
        # and ``self.optim`` holds the factory itself rather than an optimizer
        # instance. Confirm this is what the training loop expects.
        optim_create_func(self.model, lr)
        self.optim = optim_create_func
        self.device = device
        self.loss_func = loss_function
        self.visual_helper = visual_helper
        self.model_saver = model_saver
        self.start_epoch = 0
        self.model.to(device)
コード例 #4
0
ファイル: model.py プロジェクト: FreddoIsHere/TradingBot
 def __init__(self, path=current_folder, learning_rate=1e-3, batch_size=128):
     """Seed torch, build the data pipeline and load the saved network if present.

     Args:
         path: directory that may contain ``net.pth``.
         learning_rate: stored on the instance as ``learning_rate``.
         batch_size: stored on the instance and passed to ``DataCreator``.
     """
     torch.manual_seed(12345)
     self.path = path
     self.batch_size = batch_size
     self.data_creator = DataCreator(self.batch_size)
     self.learning_rate = learning_rate
     try:
         # NOTE: torch.load unpickles a whole model object; only load trusted files.
         self.net = torch.load(self.path + "/net.pth")
     except Exception:
         # Best-effort fallback to a fresh network. ``Exception`` rather than a
         # bare ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
         print("-----------------------\n"
               "No models were loaded! \n"
               "-----------------------")
         self.net = Net(input_dim=225, hidden_dim=450)
     else:
         print("--------------------------------\n"
               "Models were loaded successfully! \n"
               "--------------------------------")
     self.net.cuda()
コード例 #5
0
def run_main():
    """Replay recorded gradients on a network (open loop) and log test metrics.

    Loads the initial weights ``weights.pth`` and every file in
    ``gradients/test`` whose name contains ``"epoch"``. The epoch files are
    processed in numeric order. Each file's gradients are applied to the model
    with ``train_grad``, and the model is then evaluated on the MNIST test set.
    One ``loss accuracy`` row per file is written to
    ``gradients/test/train_loaded.txt``.
    """
    import re

    def _epoch_sort_key(path):
        # Sort by the last number in the file name, so that epoch_10 comes after
        # epoch_9. A plain string sort puts it between epoch_1 and epoch_2.
        numbers = re.findall(r"\d+", os.path.basename(path))
        return (int(numbers[-1]) if numbers else -1, path)

    gradient_path = "gradients/test"
    epoch_files = sorted(
        (os.path.join(gradient_path, x) for x in os.listdir(gradient_path)
         if "epoch" in x),
        key=_epoch_sort_key)

    gradient_save_path = "gradients/test"
    os.makedirs(gradient_save_path, exist_ok=True)

    device = torch.device("cuda")
    model = Net(5).to(device)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=256,
        shuffle=True)

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    model.load_state_dict(
        torch.load(os.path.join(gradient_path, "weights.pth")))
    training_curve = []
    for epoch_file in epoch_files:
        # NOTE: unpickling can execute arbitrary code; only replay trusted,
        # locally produced gradient files. ``with`` closes the handle, which
        # previously leaked.
        with open(epoch_file, "rb") as grad_file:
            grads = pkl.load(grad_file)
        train_grad(grads, model, device, optimizer)
        loss, accuracy = test(model, device, test_loader)
        training_curve.append([loss, accuracy])
    np.savetxt(os.path.join(gradient_save_path, "train_loaded.txt"),
               np.array(training_curve))
コード例 #6
0
def run_main():
    """Train ``Net(5)`` on MNIST for five epochs through ``process`` and track test metrics.

    Each epoch runs ``process(..., n_steps=100)`` with a ``GradientFunc``
    instance and then evaluates the model on the test set. What ``process``
    does with the gradients is defined elsewhere; this function itself writes
    no files.

    Returns:
        list of ``[loss, accuracy]`` pairs, one per epoch, as returned by ``test``.
    """
    device = torch.device("cuda")
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=256,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=512,
        shuffle=True)

    model = Net(5).to(device)
    model.train()

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    grad_func = GradientFunc()
    training_curve = []
    for _ in range(1, 6):
        process(model, device, train_loader, optimizer, grad_func, n_steps=100)
        # Record the test results; previously they were computed and discarded,
        # and training_curve stayed empty.
        loss, accuracy = test(model, device, test_loader)
        training_curve.append([loss, accuracy])
    return training_curve
コード例 #7
0
ファイル: train_mnist.py プロジェクト: soravux/AudioGrad
def run_main():
    """Train ``Net(5)`` on MNIST for five epochs, testing after every epoch.

    Runs on CUDA. A ``summary`` of the model for 1x28x28 inputs is printed
    before training starts.
    """
    device = torch.device("cuda")

    def make_loader(train, **mnist_kwargs):
        # Both splits use batch size 256, shuffling and the same MNIST
        # normalisation; each split gets its own transform pipeline.
        dataset = datasets.MNIST(
            '../data',
            train=train,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]),
            **mnist_kwargs)
        return torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

    train_loader = make_loader(True, download=True)
    test_loader = make_loader(False)

    model = Net(5).to(device)
    summary(model, (1, 28, 28))
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(1, 6):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
コード例 #8
0
 def __init__(self,
              dataloader: DataLoader,
              img_save_path: str,
              model_pth_path: str,
              device: torch.device = torch.device('cpu'),
              label_trans_func: LABEL_TRANS_FUNC = lambda x: x,
              *args,
              **kwargs) -> None:
     """Prepare a ``Net`` for inference from a saved checkpoint.

     Args:
         dataloader: stored as ``self._dataloader``.
         img_save_path: output directory; created here with ``mkdirs``.
         model_pth_path: checkpoint file whose ``'model'`` entry is the ``Net``
             state dict.
         device: device the network is moved to (default CPU).
         label_trans_func: stored as ``self._label_transform_func``; the
             identity function by default.
         *args, **kwargs: forwarded to the base-class ``__init__``.
     """
     super(Tester, self).__init__(*args, **kwargs)
     self._dataloader = dataloader
     self._img_save_path = img_save_path
     mkdirs(img_save_path)
     self._device = device
     self._label_transform_func = label_trans_func

     # Load the weights onto the CPU first, so a checkpoint saved on any device
     # can be restored. Then move the populated network to ``device``.
     weights = torch.load(model_pth_path, map_location='cpu')['model']
     net = Net()
     net.load_state_dict(weights)
     del weights  # the values now live in ``net``; release the loaded copy
     self._model = net.to(device)
     self._model.eval()
コード例 #9
0
ファイル: main.py プロジェクト: birdortyedi/machine-learning
                                      # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                      transforms.Normalize((0.1307,), (0.3081,))
                                      ])

# Evaluation preprocessing: no augmentation, only tensor conversion and
# normalisation with the MNIST mean/std constants.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = MNIST(root="./mnist", train=True, download=True, transform=train_transform)
test_dataset = MNIST(root="./mnist", train=False, download=True, transform=test_transform)

train_loader = data.DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = data.DataLoader(test_dataset, batch_size=64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net()
if torch.cuda.device_count() > 1:
    # Replicate the network across every visible GPU.
    print(f"Using {torch.cuda.device_count()} GPUs...")
    net = nn.DataParallel(net)
net.to(device)

loss_fn = nn.CrossEntropyLoss()
lr = 0.003
# Adam with a small weight decay; each scheduler.step() multiplies the
# learning rate by 0.95.
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-6)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)


def compute_confusion_matrix(init_cm, y_true, y_pred):
    """Add one count per (true, predicted) label pair to ``init_cm``, in place.

    Args:
        init_cm: matrix-like object indexed as ``init_cm[true][pred]``
            (e.g. nested lists or a 2-D NumPy array); updated in place.
        y_true: sequence of true class indices.
        y_pred: sequence of predicted class indices, same length as ``y_true``.

    Returns:
        ``init_cm`` itself, for convenience.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    if len(y_true) != len(y_pred):
        # Previously a shorter y_pred raised IndexError and a longer one was
        # silently truncated; reject the mismatch explicitly instead.
        raise ValueError(
            "y_true and y_pred must have the same length "
            f"(got {len(y_true)} and {len(y_pred)})"
        )
    for true_label, pred_label in zip(y_true, y_pred):
        init_cm[true_label][pred_label] += 1
    return init_cm
コード例 #10
0
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=batch_size,
                                           shuffle=True)
val_set = datasets.ImageFolder('../surrogate_dataset/unlab_dataset_035/val_set/',
                               transform=transforms.Compose([transforms.ToTensor(), normalize]))
val_loader = torch.utils.data.DataLoader(val_set,
                                         batch_size=batch_size,
                                         shuffle=False)

# NOTE(review): this counts every entry in train_set/, including any stray
# non-directory files. If ``train_set`` is an ImageFolder,
# ``len(train_set.classes)`` would be the authoritative count - confirm.
nb_classes = len(os.listdir('../surrogate_dataset/unlab_dataset_035/train_set/'))
# Single-argument print(...) behaves identically on Python 2 and 3, unlike the
# Python-2-only print statement, which is a SyntaxError on Python 3.
print("Training with " + str(nb_classes) + " classes")

# define a CNN and move it to the GPU
net = Net(nb_classes).cuda()
print("Model defined")
print("Model to GPU")

initial_lr = 0.01

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=0.9, weight_decay=1e-4)
lr = optimizer.param_groups[0]['lr']
print("Initial learning rate: " + str(lr))

# Start training: per-epoch metric histories
loss_history = []
accuracy_val_history = []
accuracy_train_history = []
val_loss_history = []
コード例 #11
0
ファイル: main.py プロジェクト: wayneowen7/GSAPool

def save_result(test_acc, save_path):
    """Append one result line to ``<save_path>/result.txt``.

    The line is
    ``<dataset>;pooling_layer_type:<type>;feature_fusion_type:<type>;<test_acc*100>``
    followed by a CR LF pair (``'\\r\\n'``). The dataset and layer settings are
    read from the module-level ``args``.

    Args:
        test_acc: test accuracy, presumably a fraction in [0, 1]; it is written
            multiplied by 100.
        save_path: existing directory that holds ``result.txt``.
    """
    # Scale into a new object. ``test_acc *= 100`` mutated the caller's value in
    # place whenever it was a tensor or NumPy array.
    percent = test_acc * 100
    with open(os.path.join(save_path, 'result.txt'), 'a') as f:
        f.write(args.dataset + ";")
        f.write("pooling_layer_type:" + args.pooling_layer_type + ";")
        f.write("feature_fusion_type:" + args.feature_fusion_type + ";")
        f.write(str(percent))
        f.write('\r\n')


# Training configuration: data loaders, model and optimizer, all driven by args.
train_loader, val_loader, test_loader = data_builder(args)
model = Net(args).to(args.device)
optimizer = torch.optim.Adam(
    model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)

# Early-stopping bookkeeping: presumably the number of epochs without
# improvement and the best loss seen so far (starting from args.min_loss).
patience, min_loss = 0, args.min_loss
for epoch in range(args.epochs):
    model.train()
    for i, data in enumerate(train_loader):
        data = data.to(args.device)
        out = model(data)
        loss = F.nll_loss(out, data.y)
        print("Training loss:{}".format(loss.item()))
        loss.backward()
コード例 #12
0
ファイル: convert.py プロジェクト: doreenfan/ml-reactions
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn

import sys
sys.path.insert(-1,'../../maestroflame')

from networks import Net, OC_Net

## Read Model file

filename = "example_model.pt"

model = Net(16, 64, 128, 64, 14)

# map_location="cpu" lets a checkpoint saved from GPU tensors be converted on a
# CPU-only machine. load_state_dict copies the values into the model's existing
# (CPU) parameters, so the converted result is the same either way.
model.load_state_dict(torch.load(filename, map_location="cpu"))
# NOTE(review): eval() is disabled, so the module keeps its default training
# mode when it is scripted and saved. Enable it if ts_model.pt is meant for
# inference (this matters if Net contains dropout / batch-norm layers).
# model.eval()

print(model)

## Converting to Torch Script (Annotation)

# Compile the module with torch.jit.script and save the TorchScript archive.
net_module = torch.jit.script(model)
net_module.save("ts_model.pt")
print(net_module.code)
コード例 #13
0
ファイル: trainer.py プロジェクト: pranoyr/vehicle-reid
		transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
			0.229, 0.224, 0.225])
	])

	training_data = TripletVeriDataset(
		root_dir=opt.train_images, xml_path=opt.train_annotation_path, transform=train_transform)
	validation_data = TripletVeriDataset(
		root_dir=opt.test_images, xml_path=opt.test_annotation_path, transform=test_transform)

	train_loader = torch.utils.data.DataLoader(training_data,
											   batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

	val_loader = torch.utils.data.DataLoader(validation_data,
											 batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

	embedding_net = Net()
	# embedding_net = MobileNetv2()
	model = TripletNet(embedding_net).to(device)
	loss_fn = nn.TripletMarginLoss(margin=0.5)
	# loss_fn = TripletLoss(0.5)

	# optimizer = optim.Adadelta(
	# 	model.parameters(), lr=opt.learning_rate, weight_decay=5e-4)
	optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
	# scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
	scheduler = ReduceLROnPlateau(
			optimizer, 'min', patience=5)
	

	if opt.resume_path:
		print('loading checkpoint {}'.format(opt.resume_path))
コード例 #14
0
ファイル: main.py プロジェクト: SourceCode1037/SUGPool
            index_test = torch.from_numpy(index_test).to(torch.long)
            train_validation = dataset[index_train_validation]
            test_set = dataset[index_test]
            num_training = int(len(train_validation) * 0.9)
            num_val = len(train_validation) - num_training

            training_set, validation_set = random_split(
                train_validation, [num_training, num_val])
            train_loader = DataLoader(training_set,
                                      batch_size=args.batch_size,
                                      shuffle=True)
            val_loader = DataLoader(validation_set,
                                    batch_size=args.batch_size,
                                    shuffle=False)
            test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
            model = Net(args).to(args.device)
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
            run(model)
    else:  # one random fold validation
        num_training = int(len(dataset) * 0.9)
        num_test = len(dataset) - (num_training)
        training_set, test_set = random_split(dataset,
                                              [num_training, num_test])

        train_loader = DataLoader(training_set,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        test_loader = DataLoader(test_set,
                                 batch_size=args.batch_size,
コード例 #15
0
ファイル: models.py プロジェクト: Pandade1997/DDAEC_min_fbank
    def test(self, args):
        """Run the trained network over every entry of ``args.test_list``.

        For each listed mixture file, an ``EvalDataset`` is built and each
        utterance is passed through ``self.eval_forward``. ``<name>`` below is
        the entry's base name with ``.samp`` replaced. The clean reference and
        the network estimate are written as float32 HDF5 datasets keyed by the
        utterance index:

        * clean reference: ``<prediction_path>/<name>_s_ideal.dat``
        * network estimate: ``<prediction_path>/<name>_s_est.dat``

        Per-utterance MSE and timings are printed, followed by the average loss
        of each list entry.

        Args:
            args: namespace with ``test_list`` (text file, one entry per line),
                ``model_name``, ``model_file`` (checkpoint path),
                ``test_mixture_path`` and ``prediction_path``.
        """
        with open(args.test_list, 'r') as test_list_file:
            self.test_list = [line.strip() for line in test_list_file]
        self.model_name = args.model_name
        self.model_file = args.model_file
        self.test_mixture_path = args.test_mixture_path
        self.prediction_path = args.prediction_path

        # Build the network and restore the trained weights.
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)
        net.eval()
        print('Load model from "%s"' % self.model_file)
        checkpoint = Checkpoint()
        checkpoint.load(self.model_file)
        net.load_state_dict(checkpoint.state_dict)

        num_lists = len(self.test_list)
        with torch.no_grad():
            for i, test_entry in enumerate(self.test_list):
                filename_input = test_entry.split('/')[-1]
                start1 = timeit.default_timer()
                print('{}/{}, Started working on {}.'.format(i + 1, num_lists, test_entry))
                print('')
                mix_file = os.path.join(self.test_mixture_path,
                                        filename_input.replace('.samp', '_mix.dat'))
                ideal_file = os.path.join(self.prediction_path,
                                          filename_input.replace('.samp', '_s_ideal.dat'))
                est_file = os.path.join(self.prediction_path,
                                        filename_input.replace('.samp', '_s_est.dat'))
                # Context managers close (and flush) all three HDF5 files even if
                # an utterance fails; previously an exception leaked the handles.
                with h5py.File(mix_file, 'r') as f_mix, \
                        h5py.File(ideal_file, 'w') as f_s_ideal, \
                        h5py.File(est_file, 'w') as f_s_est:
                    test_set = EvalDataset(os.path.join(self.test_mixture_path, test_entry),
                                           self.num_test_sentences)
                    test_loader = DataLoader(test_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             collate_fn=EvalCollate())

                    accu_test_loss = 0.0
                    ttime = 0.0
                    cnt = 0
                    for k, (mix_raw, cln_raw) in enumerate(test_loader):
                        start = timeit.default_timer()
                        est_s = self.eval_forward(mix_raw, net)
                        # Trim the estimate to the input length. This assumes
                        # mix_raw is a NumPy array, so .size is its element count.
                        est_s = est_s[:mix_raw.size]
                        # NOTE(review): the stored mixture is read but not used below.
                        mix = f_mix[str(k)][:]

                        ideal_s = cln_raw
                        f_s_ideal.create_dataset(str(k), data=ideal_s.astype(np.float32), chunks=True)
                        f_s_est.create_dataset(str(k), data=est_s.astype(np.float32), chunks=True)

                        test_loss = np.mean((est_s - ideal_s) ** 2)
                        accu_test_loss += test_loss
                        cnt += 1
                        curr_time = timeit.default_timer() - start
                        ttime += curr_time
                        # Running mean of per-utterance time. The old code then applied
                        # an incremental-mean update to this already-averaged value,
                        # counting the current utterance twice.
                        mtime = ttime / cnt
                        print('{}/{}, test_loss = {:.4f}, time/utterance = {:.4f}, '
                              'mtime/utternace = {:.4f}'.format(k + 1, self.num_test_sentences, test_loss, curr_time,
                                                                mtime))

                    # An entry with no utterances no longer raises ZeroDivisionError.
                    avg_test_loss = accu_test_loss / cnt if cnt else float('nan')
                    print('average test_loss = {:.4f}'.format(avg_test_loss))
                    end1 = timeit.default_timer()
                    print('********** Finisehe working on {}. time taken = {:.4f} **********'.format(filename_input,
                                                                                                     end1 - start1))
                    print('')
コード例 #16
0
ファイル: model.py プロジェクト: FreddoIsHere/TradingBot
class Model:
    """Wrapper around the trading ``Net``: loading/saving, training, testing and prediction.

    The network scores three classes, which ``predict_signal`` maps to
    ``'SELL'`` (0), ``'BUY'`` (1) and ``'HOLD'`` (2). All computation runs on CUDA.
    """

    def __init__(self, path=current_folder, learning_rate=1e-3, batch_size=128):
        """Seed torch, build the data pipeline and load the saved network if present.

        Args:
            path: directory holding ``net.pth``; ``save`` writes there too.
            learning_rate: AdamW learning rate used by ``train``.
            batch_size: batch size passed to ``DataCreator``.
        """
        torch.manual_seed(12345)
        self.path = path
        self.batch_size = batch_size
        self.data_creator = DataCreator(self.batch_size)
        self.learning_rate = learning_rate
        try:
            # NOTE: torch.load unpickles a whole model object; only load trusted files.
            self.net = torch.load(self.path + "/net.pth")
        except Exception:
            # Best-effort fallback to a fresh network. ``Exception`` rather than a
            # bare ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
            print("-----------------------\n"
                  "No models were loaded! \n"
                  "-----------------------")
            self.net = Net(input_dim=225, hidden_dim=450)
        else:
            print("--------------------------------\n"
                  "Models were loaded successfully! \n"
                  "--------------------------------")
        self.net.cuda()

    def predict_signal(self, ticker):
        """Predict a trading signal from the last row of *ticker*'s daily data.

        Returns:
            ``(signal, confidence)``. ``signal`` is ``'SELL'``, ``'BUY'`` or
            ``'HOLD'``; ``confidence`` is that class's softmax probability in
            percent.
        """
        signals = ['SELL', 'BUY', 'HOLD']
        _, data = get_daily_data(ticker, compact=True)
        self.net.train(False)
        with torch.no_grad():
            # ``features`` rather than ``input``, to avoid shadowing the builtin.
            features = torch.tensor(data.to_numpy()[-1]).float().cuda()
            output = F.softmax(self.net(features), dim=-1).cpu().numpy()
            signal_idx = np.argmax(output)
        return signals[int(signal_idx)], 100 * output[signal_idx]

    def test(self):
        """Evaluate on the testing stock data and print batch-averaged metrics.

        Prints the mean cross-entropy loss, the overall accuracy and, for each
        class, a one-vs-rest accuracy. The one-vs-rest accuracy is how often
        "label is class c" agrees with "prediction is class c". Accuracies are
        in percent.
        """
        losses = []
        accuracies = []
        buy_accuracies = []
        sell_accuracies = []
        hold_accuracies = []
        data_loader = self.data_creator.provide_testing_stock()
        criterion = nn.CrossEntropyLoss()
        self.net.train(False)
        with torch.no_grad():
            for batch_x, batch_y in data_loader:
                batch_x = batch_x.float().cuda()
                batch_y = batch_y.long().cuda()

                output = self.net(batch_x)
                loss = criterion(output, batch_y)

                output_metric = np.argmax(F.softmax(output, dim=1).cpu().numpy(), axis=1)
                batch_size = batch_y.size()[0]
                batch_y = batch_y.cpu().numpy()
                sell_accuracies.append(100 * ((batch_y == 0) == (output_metric == 0)).sum() / batch_size)
                buy_accuracies.append(100 * ((batch_y == 1) == (output_metric == 1)).sum() / batch_size)
                hold_accuracies.append(100 * ((batch_y == 2) == (output_metric == 2)).sum() / batch_size)
                losses.append(loss.item())
                accuracies.append(100 * (output_metric == batch_y).sum() / batch_size)
        print("Average loss: ", np.mean(losses))
        print("Average accuracy: ", np.mean(accuracies))
        print("Buy-Average accuracy: ", np.mean(buy_accuracies))
        print("Sell-Average accuracy: ", np.mean(sell_accuracies))
        print("Hold-Average accuracy: ", np.mean(hold_accuracies))

    def train(self, epochs):
        """Train for *epochs* passes over the training data, then plot metrics.

        Every second batch records:

        * the loss;
        * one-vs-one ROC-AUC of the network;
        * one-vs-one ROC-AUC of a uniformly random baseline;
        * accuracy.

        These series are plotted as 25-point moving averages.
        """
        rocs_aucs = []
        baseline_rocs_aucs = []
        losses = []
        accuracies = []
        # NOTE(review): ``class_weights`` is returned by the data creator but never
        # passed to the loss. Confirm whether ``nn.CrossEntropyLoss(weight=...)``
        # was intended.
        data_loader, class_weights = self.data_creator.provide_training_stock()
        criterion = nn.CrossEntropyLoss()
        optimiser = optim.AdamW(self.net.parameters(), lr=self.learning_rate, weight_decay=1e-5, amsgrad=True)
        # Stepped once per batch with that batch's training loss.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, patience=220, min_lr=1e-9)
        self.net.train(True)
        pbar = tqdm(total=epochs)

        for _ in range(epochs):
            for i, (batch_x, batch_y) in enumerate(data_loader):
                batch_x = batch_x.float().cuda()
                batch_y = batch_y.long().cuda()
                # The last batch of an epoch can be smaller than self.batch_size
                # (unless the loader drops it). Sizing metrics by self.batch_size
                # then raised IndexError and gave the random baseline a
                # mismatched length.
                batch_size = batch_y.size(0)

                self.net.zero_grad()
                output = self.net(batch_x)
                loss = criterion(output, batch_y)
                loss.backward()
                optimiser.step()

                scheduler.step(loss.item())

                # Record metrics on every second batch.
                if i % 2 == 0:
                    output_metric = F.softmax(output.detach().cpu(), dim=1).numpy()
                    random_metric = relabel_data(
                        np.random.choice([0, 1, 2], size=(1, batch_size), p=[1/3, 1/3, 1/3]))
                    label_metric = relabel_data(batch_y.detach().cpu().numpy())
                    losses.append(loss.item())
                    rocs_aucs.append(roc_auc_score(label_metric, output_metric, multi_class='ovo'))
                    baseline_rocs_aucs.append(roc_auc_score(label_metric, random_metric, multi_class='ovo'))
                    correct = sum(int(np.argmax(output_metric[k]) == np.argmax(label_metric[k]))
                                  for k in range(batch_size))
                    accuracies.append(100 * correct / batch_size)
            pbar.update(1)
        pbar.close()

        kernel = np.ones(25) / 25  # 25-point moving average

        def smooth(values):
            return np.convolve(values, kernel, mode='valid')

        _, axs = plt.subplots(1, 3)
        axs[0].plot(smooth(losses))
        axs[1].plot(smooth(rocs_aucs))
        axs[1].plot(smooth(baseline_rocs_aucs))
        axs[1].legend(['Net', 'Baseline'])
        axs[2].plot(smooth(accuracies))
        plt.show()

    def save(self):
        """Pickle the whole network object to ``<path>/net.pth`` (reloaded by ``__init__``)."""
        torch.save(self.net, self.path + "/net.pth")
コード例 #17
0
test_min_dir = '/scratch/additya/graph_data/DB2_A_cleaned_processed_min'
# # im_dir = test_im_dir
# # min_dir = test_min_dir

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# CUDA_VISIBLE_DEVICES is often unset or empty on CPU-only machines.
# Indexing os.environ raised KeyError in the unset case, and int('') raised
# ValueError in the empty case. Missing or blank entries now yield no ids.
device_ids = [
    int(device_id)
    for device_id in os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',')
    if device_id.strip()
]
print('GPU configuration:', device, device_ids)

# # device = torch.device('cuda')
# # encoder = NodeEncoder().to(device)
# Pretrained ResNet-18 whose final layer is replaced by a 512 -> 512 linear layer.
encoder = resnet18(pretrained=True)
encoder.fc = torch.nn.Linear(512, 512)
propagator = GraphPropagator().to(device)
gencoder = GraphEncoder().to(device)
model = Net(encoder, propagator, gencoder)
# model = torch.nn.DataParallel(model, device_ids=device_ids).to(device)
# clf = torch.nn.Linear(256, 2000)
# clf = torch.nn.DataParallel(clf, device_ids=device_ids).to(device)
model = model.to(device)
# Linear head from 256-d features to 2000 outputs (presumably one per class - confirm).
clf = torch.nn.Linear(256, 2000).to(device)

criterion = torch.nn.CrossEntropyLoss()
# # criterion = OnlineTripletLoss(2.0, SemihardNegativeTripletSelector(2.0))

# print("Net done!")

train_dataset = MinutiaeDataset(im_dir, min_dir, 'train')
# # sampler = BalancedBatchSampler(dataset.labels, n_classes=10, n_samples=5)
# # train_loader = DataLoader(dataset, num_workers=10, batch_sampler=sampler)
train_loader = DataLoader(train_dataset,
コード例 #18
0
    def train(self, args):
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [
                line.strip() for line in train_list_file.readlines()
            ]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file, self.num_test_sentences)
        #trainSet = evalSet
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())
        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        #net = torch.nn.DataParallel(net)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001
                                                                      ] * 3
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        batch_timings = 0.
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        print('best_loss', best_loss)
        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes) in enumerate(train_loader):
                i += start_iter
                features, labels = features.to(self.device), labels.to(
                    self.device)

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)
                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                loss = 0.8 * loss1 + 0.2 * loss2
                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.data.item()
                accu_train_loss += running_loss

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'
                    .format(i + 1, num_train_batches, epoch + 1,
                            self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader)

                    net.train()

                    print(
                        'Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )'
                        % (epoch + 1, self.max_epoch, i + 1,
                           self.num_train_sentences // self.batch_size,
                           avg_train_loss, avg_eval_loss))

                    is_best = True if avg_eval_loss < best_loss else False
                    best_loss = avg_eval_loss if is_best else best_loss

                    checkpoint = Checkpoint(epoch, i, avg_train_loss,
                                            avg_eval_loss, best_loss,
                                            net.state_dict(),
                                            optimizer.state_dict())

                    model_name = self.model_name + '_latest.model'
                    best_model = self.model_name + '_best.model'
                    checkpoint.save(is_best,
                                    os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt',
                            checkpoint, self.eval_steps)
                    #metric_logging(self.log_path, self.model_name +'_metric_log.txt', epoch+1, [avg_st, avg_sn, avg_pe])
                    accu_train_loss = 0.0
                    cnt = 0.

                    net.train()
                if (i + 1) % num_train_batches == 0:
                    break

            avg_st, avg_sn, avg_pe = self.validate_with_metrics(
                net, eval_loader)
            net.train()
            print('#' * 50)
            print('')
            print(
                'After {} epoch the performance on validation score is a s follows:'
                .format(epoch + 1))
            print('')
            print('STOI: {:.4f}'.format(avg_st))
            print('SNR: {:.4f}'.format(avg_sn))
            print('PESQ: {:.4f}'.format(avg_pe))
            for param_group in optimizer.param_groups:
                print('learning_rate', param_group['lr'])
            print('')
            print('#' * 50)
            checkpoint = Checkpoint(epoch, 0, None, None, best_loss,
                                    net.state_dict(), optimizer.state_dict())
            checkpoint.save(
                False,
                os.path.join(self.model_path,
                             self.model_name + '-{}.model'.format(epoch + 1)),
                os.path.join(self.model_path, best_model))
            metric_logging(self.log_path, self.model_name + '_metric_log.txt',
                           epoch, [avg_st, avg_sn, avg_pe])
            start_iter = 0.
コード例 #19
0
ファイル: models.py プロジェクト: Pandade1997/DDAEC_min_fbank
    def train(self, args):
        """Train the enhancement network with waveform, STFT-magnitude and fbank losses.

        Copies the run configuration from ``args`` onto ``self``, builds the
        training/evaluation datasets and loaders, creates the network and an
        Adam optimizer, optionally resumes from ``self.resume_model``, and then
        trains for epochs ``start_epoch .. self.max_epoch - 1``.

        Every ``self.eval_steps`` iterations the model is validated with
        ``self.validate`` and the latest (and, if improved, best) checkpoint is
        saved. At the end of every epoch STOI/SNR/PESQ are computed with
        ``self.validate_with_metrics`` and a per-epoch checkpoint is written.

        Args:
            args: parsed arguments; the attributes read here are train_list,
                eval_file, num_train_sentences, batch_size, lr, max_epoch,
                model_path, log_path, fig_path, eval_plot_num, eval_steps,
                resume_model, wav_path, train_wav_path and tool_path.
        """
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [line.strip() for line in train_list_file.readlines()]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.train_wav_path = args.train_wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file,
                              self.num_test_sentences)
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())

        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        # Per-epoch learning-rate schedule (15 entries). It overrides self.lr
        # at the start of every epoch and is indexed by epoch, so max_epoch
        # must not exceed len(self.lr_list).
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0  # batches seen in this run; used for the mean time/batch
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        iteration = 0  # global step used as the x-axis for `summary`
        # Checkpoint file names are fixed for the whole run. Defining them up
        # front also keeps the per-epoch save valid when no mid-epoch
        # evaluation has happened yet.
        model_name = self.model_name + '_latest.model'
        best_model = self.model_name + '_best.model'
        # The filter-bank extractor for the auxiliary loss has fixed
        # parameters, so build it once instead of once per batch.
        feature_maker = Fbank(sample_rate=16000, n_fft=400, n_mels=40)
        print('best_loss', best_loss)
        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes, feat_size, label_size, get_filename) in enumerate(
                    train_loader):  # example shape: features torch.Size([4, 1, 250, 512])
                iteration += 1
                # When resuming, continue the batch numbering from the
                # checkpointed iteration (only for the first resumed epoch).
                i += start_iter
                features, labels = features.to(self.device), labels.to(self.device)

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)  # example shape: torch.Size([4, 1, 64256])

                loss_fbank = self._batch_fbank_loss(net, feature_maker, get_filename, feat_size)

                # Trim the network output to the label length.
                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)
                loss = 0.4 * loss1 + 0.1 * loss2 + 0.5 * loss_fbank

                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.item()
                accu_train_loss += running_loss

                # NOTE(review): this logs the loss *accumulated since the last
                # evaluation*, not the current batch loss -- confirm intent.
                summary.add_scalar('Train Loss', accu_train_loss, iteration)

                cnt += 1.
                counter1 += 1

                # Drop references so GPU memory can be reused by the next batch.
                del loss, loss_fbank, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'.format(
                        i + 1,
                        num_train_batches, epoch + 1, self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader, iteration)

                    net.train()

                    print('Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
                        epoch + 1, self.max_epoch, i + 1, num_train_batches,
                        avg_train_loss,
                        avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    if is_best:
                        best_loss = avg_eval_loss

                    checkpoint = Checkpoint(epoch, i, avg_train_loss, avg_eval_loss, best_loss, net.state_dict(),
                                            optimizer.state_dict())
                    checkpoint.save(is_best, os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt', checkpoint, self.eval_steps)
                    accu_train_loss = 0.0
                    cnt = 0.
                    # Restart the batch timer after evaluation/checkpointing so
                    # that time is not attributed to the next training batch.
                    start = timeit.default_timer()
                if (i + 1) % num_train_batches == 0:
                    break

            # End-of-epoch evaluation with objective speech metrics.
            avg_st, avg_sn, avg_pe = self.validate_with_metrics(net, eval_loader)
            net.train()
            print('#' * 50)
            print('')
            print('After {} epoch the performance on validation score is as follows:'.format(epoch + 1))
            print('')
            print('STOI: {:.4f}'.format(avg_st))
            print('SNR: {:.4f}'.format(avg_sn))
            print('PESQ: {:.4f}'.format(avg_pe))
            for param_group in optimizer.param_groups:
                print('learning_rate', param_group['lr'])
            print('')
            print('#' * 50)
            checkpoint = Checkpoint(epoch, 0, None, None, best_loss, net.state_dict(), optimizer.state_dict())
            checkpoint.save(False, os.path.join(self.model_path, self.model_name + '-{}.model'.format(epoch + 1)),
                            os.path.join(self.model_path, best_model))
            metric_logging(self.log_path, self.model_name + '_metric_log.txt', epoch, [avg_st, avg_sn, avg_pe])
            # Only the first (resumed) epoch starts part-way through the data.
            start_iter = 0

    def _batch_fbank_loss(self, net, feature_maker, get_filename, feat_size):
        """Return the auxiliary filter-bank MSE loss for one training batch.

        For each HDF5 file in ``get_filename``:
          * the ``noisy_raw`` signal is enhanced with ``self.train_asr_forward``
            and truncated to ``feat_size[t][0]`` samples;
          * the estimate and the ``clean_raw`` reference are written as wav
            files to ``self.train_wav_path`` and read back;
          * filter-bank features of both are compared with mean-squared error.

        The summed loss is divided by ``100 * len(get_filename)``. An empty
        batch yields 0.0.

        NOTE(review): the estimate is round-tripped through a wav file on
        disk, so autograd cannot trace this loss back to ``net``. It shifts
        the reported loss value but does not contribute gradient to ``net`` --
        confirm this is intended.
        """
        if not get_filename:
            return 0.0
        loss_fbank = 0
        for t, filename in enumerate(get_filename):
            # Close the HDF5 file as soon as both signals have been read.
            with h5py.File(filename, 'r') as reader:
                feature_asr = reader['noisy_raw'][:]
                label_asr = reader['clean_raw'][:]

            feat_asr_size = int(feat_size[t][0].item())

            output_asr = self.train_asr_forward(feature_asr, net)
            est_output_asr = output_asr[:feat_asr_size]
            ideal_labels_asr = label_asr

            # save the training wavs
            est_path = os.path.join(self.train_wav_path, '{}_est.wav'.format(t + 1))
            ideal_path = os.path.join(self.train_wav_path, '{}_ideal.wav'.format(t + 1))
            sf.write(est_path, normalize_wav(est_output_asr)[0], self.srate)
            sf.write(ideal_path, normalize_wav(ideal_labels_asr)[0], self.srate)

            # read the wavs back and compute filter-bank features
            est_sig = sb.dataio.dataio.read_audio(est_path).unsqueeze(axis=0).to(self.device)
            ideal_sig = sb.dataio.dataio.read_audio(ideal_path).unsqueeze(axis=0).to(self.device)
            est_sig_feats = feature_maker(est_sig)
            ideal_sig_feats = feature_maker(ideal_sig)

            # reduction='mean' is what the deprecated positional
            # size_average=True resolved to.
            loss_fbank += F.mse_loss(est_sig_feats, ideal_sig_feats, reduction='mean')

        return loss_fbank / (100 * len(get_filename))
コード例 #20
0
ファイル: main.py プロジェクト: ShuGuoJ/DiffPool
# Split the dataset 80% / 10% / 10% into train / validation / test; the
# test split takes whatever the integer truncation leaves over.
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_training - num_val
training_set, validation_set, test_set = random_split(
    dataset, [num_training, num_val, num_test]
)

# Only the training loader shuffles; the test loader yields one sample per batch.
train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(validation_set, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

# Build the model, its trainer wrapper, the optimizer and the loss.
max_nodes = count_max_nodes(dataset)
model = Net(dataset.num_features, 64, dataset.num_classes, max_nodes)
trainer = Trainer(model)
optimizer = torch.optim.Adam(
    model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
criterion = torch.nn.CrossEntropyLoss()

# def eval(model,loader):
#     model.eval()
#     correct = 0.
#     loss = 0.
#     for data in loader:
#         data = data.to(args.device)
#         out = model(data)
#         pred = out.max(dim=1)[1]
#         correct += pred.eq(data.y).sum().item()
コード例 #21
0
ファイル: main.py プロジェクト: WissingChen/DeepRadiomics
def train(ini_file):
    ''' Performs training according to .ini file

    Builds the network, criterion and optimizer from the configuration,
    trains on the whole training split as a single batch per epoch and then
    evaluates on the whole test split. The model/optimizer state is saved to
    ``models_dir`` whenever the validation accuracy exceeds the best so far
    AND the training accuracy exceeds the one recorded at that best. After
    ``patience`` consecutive epochs without such an improvement, training
    stops early: the best epoch is printed and its ROC curve is plotted.

    :param ini_file: (String) the path of .ini file
    :return: tuple ``(best_acc, _best_acc)`` -- the best validation accuracy
        and the training accuracy recorded at that epoch (the initial
        thresholds 0.65 / 0.70 if no epoch ever beat them)
    '''
    # reads configuration from .ini file
    config = read_config(ini_file)
    # builds network|criterion|optimizer based on configuration
    model = Net(config['network']).to(device)
    criterion = Criterion(config['network'], device).to(device)
    # Look the optimizer class up by name instead of eval()-ing a string taken
    # from the config file: only attributes of torch.optim can be selected.
    optimizer_cls = getattr(optim, config['train']['optimizer'])
    optimizer = optimizer_cls(model.parameters(), lr=config['train']['learning_rate'])
    # constructs data loaders based on configuration; each loader yields its
    # whole split as a single batch
    train_dataset = MakeDataset(config['train']['h5_file'], is_train=True, device=device)
    test_dataset = MakeDataset(config['train']['h5_file'], is_train=False, device=device)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset))
    # Checkpoint file name: the .ini base name plus '.pth'. Normalising '\\'
    # to '/' first makes this work for both Windows- and POSIX-style paths.
    ckpt_name = os.path.basename(ini_file.replace('\\', '/')) + '.pth'
    # training
    # Best training / validation accuracy so far. They start at fixed
    # thresholds, so only epochs beating 0.70 / 0.65 are ever saved.
    _best_acc = 0.70
    best_acc = 0.65
    best_ep = 0
    flag = 0  # consecutive epochs without improvement (compared to `patience`)
    _best_auc = 0
    best_auc = 0
    best_roc = None
    for epoch in range(1, config['train']['epochs'] + 1):
        # adjusts learning rate
        lr = adjust_learning_rate(optimizer, epoch,
                                  config['train']['learning_rate'],
                                  config['train']['lr_decay_rate'])
        # train step
        model.train()
        for X, y in train_loader:
            # makes predictions
            pred = model(X)
            train_loss = criterion(pred, y, model)
            train_FPR, train_TPR, train_ACC, train_roc, train_roc_auc, _, _, _, _ = Auc(pred, y)
            # updates parameters
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        # valid step
        model.eval()
        for X, y in test_loader:
            # makes predictions
            with torch.no_grad():
                pred = model(X)
                valid_loss = criterion(pred, y, model)
                valid_FPR, valid_TPR, valid_ACC, valid_roc, valid_roc_auc, _, _, _, _ = Auc(pred, y)
                if valid_ACC > best_acc and train_ACC > _best_acc:
                    flag = 0
                    best_acc = valid_ACC
                    _best_acc = train_ACC
                    best_ep = epoch
                    best_auc = valid_roc_auc
                    _best_auc = train_roc_auc
                    best_roc = valid_roc
                    # saves the best model
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch}, os.path.join(models_dir, ckpt_name))
                else:
                    flag += 1
                    if flag >= patience:
                        print('epoch: {}\t{:.8f}({:.8f})'.format(best_ep, _best_acc, best_acc))
                        if best_roc is not None:
                            plt.plot(best_roc[:, 0], best_roc[:, 1])
                            plt.title('ep:{}  AUC: {:.4f}({:.4f}) ACC: {:.4f}({:.4f})'.format(best_ep, _best_auc, best_auc, _best_acc, best_acc))
                            plt.show()
                        return best_acc, _best_acc
        # notes that, train loader and valid loader both have one batch!!!
        print('\rEpoch: {}\tLoss: {:.8f}({:.8f})\tACC: {:.8f}({:.8f})\tAUC: {}({})\tFPR: {:.8f}({:.8f})\tTPR: {:.8f}({:.8f})\tlr: {:g}\n'.format(
            epoch, train_loss.item(), valid_loss.item(), train_ACC, valid_ACC, train_roc_auc, valid_roc_auc, train_FPR, valid_FPR, train_TPR, valid_TPR, lr), end='', flush=False)
    return best_acc, _best_acc
コード例 #22
0
ファイル: train.py プロジェクト: xia0long/drl-flappybird
def train(args):
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Network()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    fb = FlappyBird()

    image, reward, terminal, score = fb.next_frame(randint(0, 1))[:-1]

    img = image_pre_processing(image, args.img_size, args.img_size)
    img = torch.from_numpy(img)
    img.to(device)
    state = torch.cat(tuple(img for _ in range(4)))[None, :, :, :]

    replay_memory = []
    episode = 0

    while 1:
        prediction = model(state)[0]
        epsilon = args.initial_epsilon
        if np.random.uniform() > epsilon:
            action = torch.argmax(prediction)
        else:
            action = randint(0, 1)

        next_image, reward, terminal = fb.next_frame(action, '')[:3]

        if terminal:
            fb.__init__()
            next_image, reward, terminal = fb.next_frame(action, '')[:3]

        next_img = image_pre_processing(next_image, args.img_size,
                                        args.img_size)
        next_img = torch.from_numpy(next_img)
        next_state = torch.cat((state[0, 1:, :, :], next_img))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, terminal])
        if len(replay_memory) > args.replay_memory_size:
            del replay_memory[0]

        batch = sample(replay_memory, min(len(replay_memory), args.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(
            *batch)

        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.from_numpy(
            np.array([[1, 0] if action == 0 else [0, 1]
                      for action in action_batch],
                     dtype=np.float32))
        reward_batch = torch.from_numpy(
            np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state
                                           for state in next_state_batch))

        state_batch.to(device)
        action_batch.to(device)
        reward_batch.to(device)
        next_state_batch.to(device)

        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

        y_batch = torch.cat(
            tuple(reward if terminal else reward + args.gamma * torch.max(prediction) for reward, terminal, prediction in \
                zip(reward_batch, terminal_batch, next_prediction_batch))
        )
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        episode += 1