Beispiel #1
0
def main(_):
    """tf.app.run entry point: build a VDSR network from FLAGS and train it.

    The single positional argument (unused argv remainder) is required by
    tf.app.run's calling convention.
    """
    model_kwargs = dict(
        image_size=FLAGS.image_size,
        label_size=FLAGS.label_size,
        layer=FLAGS.layer,
        c_dim=FLAGS.c_dim,
    )
    with tf.Session() as sess:
        VDSR(sess, **model_kwargs).train(FLAGS)
Beispiel #2
0
def main():
    """Ensure the checkpoint directory exists, then train VDSR using Config."""
    if not os.path.exists(Config.checkpoint_dir):
        os.makedirs(Config.checkpoint_dir)

    model_kwargs = dict(
        image_size=Config.image_size,
        label_size=Config.label_size,
        batch_size=Config.batch_size,
        c_dim=Config.c_dim,
        checkpoint_dir=Config.checkpoint_dir,
        scale=Config.scale,
    )
    with tf.Session() as sess:
        VDSR(sess, **model_kwargs).train(Config)
Beispiel #3
0
 def main(self):
     """Dispatch VDSR to training or testing based on self.args.run_type.

     'train': builds the model, an Adam optimizer, and runs self.train.
     'test' : loads pretrained weights from args.pre_model_dir and runs
              self.test, after validating that both the checkpoint and the
              low-resolution test data directory exist.
     """
     global model
     print("VDSR ==> Data loading .. ")
     loader = data.Data(self.args)
     print("VDSR ==> Check run type .. ")
     if self.args.run_type == 'train':
         train_data_loader = loader.loader_train
         test_data_loader = loader.loader_test
         print("VDSR ==> Load model .. ")
         model = VDSR.VDSR()
         print("VDSR ==> Setting optimizer .. [ ", self.args.optimizer,
               " ] , lr [ ", self.args.lr, " ] , Loss [ MSE ]")
         optimizer = optim.Adam(model.parameters(), self.args.lr)
         if self.args.cuda:
             model.cuda()
         self.train(model, optimizer, self.args.epochs, train_data_loader,
                    test_data_loader)
     elif self.args.run_type == 'test':
         print("VDSR ==> Testing .. ")
         if os.path.exists(self.args.pre_model_dir):
             if not os.path.exists(self.args.dir_data_test_lr):
                 print("VDSR ==> Fail [ Test model is not exists ]")
             else:
                 test_data_loader = loader.loader_test
                 # Fix: construct the network before loading weights.  The
                 # global `model` is only assigned in the 'train' branch, so
                 # a test-only run previously hit a NameError here.
                 model = VDSR.VDSR()
                 Loaded = torch.load(self.args.pre_model_dir)
                 model.load_state_dict(Loaded)
                 if self.args.cuda:
                     model.cuda()
                 self.test(self.args, test_data_loader, model)
         else:
             print(
                 "VDSR ==> Fail [ Pretrain model directory is not exists ]")
Beispiel #4
0
def main(_):
    """tf.app.run entry point: train VDSR, or test it when FLAGS.is_train is false.

    Fix: the original mixed hard tabs with spaces for the if/else block,
    which is a TabError under Python 3; the branch is re-indented with
    spaces and anchored inside the `with` block so the session is still
    open when train/test runs.
    """
    with tf.Session() as sess:
        vdsr = VDSR(sess,
                    image_size=FLAGS.image_size,
                    label_size=FLAGS.label_size,
                    layer=FLAGS.layer,
                    c_dim=FLAGS.c_dim)
        if FLAGS.is_train:
            vdsr.train(FLAGS)
        else:
            # Test path expects 3-channel input — TODO confirm against VDSR.test.
            FLAGS.c_dim = 3
            vdsr.test(FLAGS)
Beispiel #5
0
def main():
    """Parse CLI options, build VDSR, optionally resume from a pretrained
    checkpoint, then train and evaluate.

    Fixes relative to the original:
    - uses the required --train_data argument instead of a hard-coded path;
    - nn.MSELoss(reduction='sum') replaces the deprecated size_average=False;
    - the not-found message used undefined `opt` (NameError) — now `args`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--batchSize", type=int, default=128, help="Training batch size. Default 128")
    parser.add_argument("--Epochs", type=int, default=50, help="Number of epochs to train for")
    parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1")
    parser.add_argument("--step", type=int, default=10,
                        help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10")
    parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number")
    parser.add_argument("--cuda", action="store_true", help="Use cuda?")
    parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
    parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
    parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
    parser.add_argument("--weight-decay", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
    parser.add_argument("--pretrained", default='', type=str, help="Path to pretrained model")
    parser.add_argument("--train_data", required=True, type=str, help="Path to preprocessed train dataset")
    parser.add_argument("--test_data", default="./assets/", type=str, help="Path to file containing test images")
    args = parser.parse_args()

    cuda = args.cuda
    if cuda:
        print("=> use gpu id: '{}'".format(0))
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    cudnn.benchmark = True

    # Fix: honor the (required) --train_data argument; the original always
    # read the hard-coded "data/train.h5".
    train_set = prepareDataset(args.train_data)
    train_data = DataLoader(dataset=train_set, num_workers=args.threads, batch_size=args.batchSize, shuffle=True)

    model = VDSR()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # Sum-reduced MSE; size_average=False is deprecated/removed in modern PyTorch.
    criterion = nn.MSELoss(reduction='sum')
    if cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading model '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained)
            args.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint['model'].state_dict())
        else:
            # Fix: `opt` was undefined here (NameError); report via `args`.
            print("No model found at '{}'".format(args.pretrained))

    train(args.start_epoch, train_data, optimizer, model, criterion, args.Epochs, args)
    eval(model, args)
    def test(self, mode, inference):
        """Evaluate a restored SRCNN/VDSR checkpoint on the grayscale test set.

        For each scale factor X2..X4, computes per-image PSNR against the
        labels and prints the per-scale average.  When `inference` is true,
        also writes the restored images (and, for VDSR, the learned
        residuals) under ./restored_<mode>/<date>/.

        Fixes: the per-scale glob is hoisted out of the per-image loop, and
        the average divides by the actual image count instead of a
        hard-coded 5.
        """
        # images = low resolution, labels = high resolution
        sess = self.sess

        test_label_list = sorted(glob.glob('./dataset/test/gray/*.*'))

        num_image = len(test_label_list)

        assert mode == 'SRCNN' or mode == 'VDSR'
        if mode == 'SRCNN':
            sr_model = SRCNN(channel_length=self.c_length, image=self.x)
            _, _, prediction = sr_model.build_model()
        elif mode == 'VDSR':
            sr_model = VDSR(channel_length=self.c_length, image=self.x)
            prediction, residual, _ = sr_model.build_model()

        with tf.name_scope("PSNR"):
            # PSNR = 10 * log10(255^2 / MSE)
            psnr = 10 * tf.log(255 * 255 * tf.reciprocal(
                tf.reduce_mean(tf.square(self.y - prediction)))) / tf.log(
                    tf.constant(10, dtype='float32'))

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver()

        saver.restore(sess, self.save_path)

        for j in range(2, 5):
            avg_psnr = 0
            # Hoisted: the X{j} file list depends only on the scale factor,
            # not on the image index, so glob once per scale.
            test_image_list = sorted(
                glob.glob('./dataset/test/X{}/*.*'.format(j)))
            for i in range(num_image):
                test_image = np.array(Image.open(test_image_list[i]))
                test_image = test_image[np.newaxis, :, :, np.newaxis]
                test_label = np.array(Image.open(test_label_list[i]))
                # Crop the label so each dimension is a multiple of the scale.
                h = test_label.shape[0]
                w = test_label.shape[1]
                h -= h % j
                w -= w % j
                test_label = test_label[np.newaxis, 0:h, 0:w, np.newaxis]

                final_psnr = sess.run(psnr,
                                      feed_dict={
                                          self.x: test_image,
                                          self.y: test_label
                                      })

                print('X{} : Test PSNR is '.format(j), final_psnr)
                avg_psnr += final_psnr

                if inference:
                    pred = sess.run(prediction,
                                    feed_dict={
                                        self.x: test_image,
                                        self.y: test_label
                                    })
                    pred = np.squeeze(pred).astype(dtype='uint8')
                    pred_image = Image.fromarray(pred)
                    filename = './restored_{0}/{3}/{1}_X{2}.png'.format(
                        mode, i, j, self.date)
                    pred_image.save(filename)
                    if mode == 'VDSR':
                        res = sess.run(residual,
                                       feed_dict={
                                           self.x: test_image,
                                           self.y: test_label
                                       })
                        res = np.squeeze(res).astype(dtype='uint8')
                        res_image = Image.fromarray(res)
                        filename = './restored_{0}/{3}/{1}_X{2}_res.png'.format(
                            mode, i, j, self.date)
                        res_image.save(filename)

            # Fix: average over the actual number of test images (the old
            # code hard-coded 5); max() guards an empty test set.
            print('X{} : Avg PSNR is '.format(j), avg_psnr / max(num_image, 1))
Beispiel #7
0
def main():
    """Entry point: parse options into the global `opt`, build VDSR, then run
    the train/validate loop, checkpointing every 10 epochs and printing
    timing statistics.

    Relies on module-level `parser`, `get_training_set`, `get_validation_set`,
    `DataLoader`, `VDSR`, `train`, `validate`, and `checkpoint`.
    """
    global opt
    opt = parser.parse_args()
    # gpuids arrive as strings from argparse; convert to ints.
    opt.gpuids = list(map(int, opt.gpuids))

    print(opt)

    if opt.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    cudnn.benchmark = True

    # Datasets are only built when training; the test path is commented out.
    if not opt.test:
        train_set = get_training_set(opt.dataset, opt.crop_size,
                                     opt.upscale_factor, opt.add_noise,
                                     opt.noise_std)
        validation_set = get_validation_set(opt.dataset, opt.crop_size,
                                            opt.upscale_factor)

    # test_set = get_test_set(
    #     opt.dataset, opt.crop_size, opt.upscale_factor)

    if not opt.test:
        training_data_loader = DataLoader(dataset=train_set,
                                          num_workers=opt.threads,
                                          batch_size=opt.batch_size,
                                          shuffle=True)
        validating_data_loader = DataLoader(dataset=validation_set,
                                            num_workers=opt.threads,
                                            batch_size=opt.test_batch_size,
                                            shuffle=False)

    # testing_data_loader = DataLoader(
    #     dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)

    model = VDSR()
    criterion = nn.MSELoss()

    if opt.cuda:
        torch.cuda.set_device(opt.gpuids[0])
        with torch.cuda.device(opt.gpuids[0]):
            model = model.cuda()
            criterion = criterion.cuda()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.lr,
                           weight_decay=opt.weight_decay)
    #     optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)

    # if opt.test:
    #     model_name = join("model", opt.model)
    #     model = torch.load(model_name)
    #     start_time = time.time()
    #     test(model, criterion, testing_data_loader)
    #     elapsed_time = time.time() - start_time
    #     print("===> average {:.2f} image/sec for test".format(
    #         100.0/elapsed_time))
    #     return

    # NOTE(review): when --test is set, the loaders above are never created
    # but the loop below still runs, so train() would hit a NameError — the
    # commented-out test branch needs restoring before --test is usable.
    train_time = 0.0
    validate_time = 0.0
    for epoch in range(1, opt.epochs + 1):
        start_time = time.time()
        train(model, criterion, epoch, optimizer, training_data_loader)
        elapsed_time = time.time() - start_time
        train_time += elapsed_time
        #         print("===> {:.2f} seconds to train this epoch".format(
        #             elapsed_time))
        start_time = time.time()
        validate(model, criterion, validating_data_loader)
        elapsed_time = time.time() - start_time
        validate_time += elapsed_time
        #         print("===> {:.2f} seconds to validate this epoch".format(
        #             elapsed_time))
        # Persist a checkpoint every 10 epochs.
        if epoch % 10 == 0:
            checkpoint(model, epoch)

    print("===> average training time per epoch: {:.2f} seconds".format(
        train_time / opt.epochs))
    print("===> average validation time per epoch: {:.2f} seconds".format(
        validate_time / opt.epochs))
    print("===> training time: {:.2f} seconds".format(train_time))
    print("===> validation time: {:.2f} seconds".format(validate_time))
    print("===> total training time: {:.2f} seconds".format(train_time +
                                                            validate_time))
Beispiel #8
0
from model import VDSR
from utils import *

# Inference script (fragment): load VDSR weights from --weights_file and
# open the input images.  Processing of `image` / `resample` continues
# beyond this excerpt.
if __name__ == '__main__':
    # argparse / torch / cudnn / Image presumably come from the star-import
    # of `utils` — TODO confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights_file', type=str, required=True)
    parser.add_argument('--image_file', type=str, required=True)
    parser.add_argument('--realimage_file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=1)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = VDSR().to(device)

    # Copy each saved tensor into the freshly-built model's state dict;
    # fail loudly on any parameter name the model does not have.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file,
                           map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    # model, optim = torch.load(model.state_dict(), os.path.join(args.weights_file, 'epoch_150.pth'))

    model.eval()

    image = Image.open(args.realimage_file).convert('RGB')
    resample = Image.open(args.image_file).convert('RGB')
Beispiel #9
0
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = DEVICE_GPU_ID


def restore_session_from_checkpoint(sess, saver):
    """Restore the most recent checkpoint from TRAIN_DIR into `sess`.

    Returns True when a checkpoint was found and restored, False otherwise.
    """
    latest = tf.train.latest_checkpoint(TRAIN_DIR)
    if not latest:
        return False
    saver.restore(sess, latest)
    return True


# Select the super-resolution network from the module-level MODEL constant;
# anything other than 'VDSR' falls back to EDSR.
if MODEL == 'VDSR':
    model = VDSR(scale=SCALE)
else:
    model = EDSR(scale=SCALE)

data_loader = DataLoader(data_dir=TRAIN_PNG_PATH,
                         batch_size=BATCH_SIZE,
                         shuffle_num=SHUFFLE_NUM,
                         prefetch_num=PREFETCH_NUM,
                         scale=SCALE)

# Either read pre-generated TFRecords (creating them on first run when the
# directory is empty) or read PNGs directly; both paths yield
# (low-res, bicubic, ground-truth) tensors.
if DATA_LOADER_MODE == 'TFRECORD':
    if len(os.listdir(TRAIN_TFRECORD_PATH)) == 0:
        data_loader.gen_tfrecords(TRAIN_TFRECORD_PATH)
    lrs, bics, gts = data_loader.read_tfrecords(TRAIN_TFRECORD_PATH)
else:
    lrs, bics, gts = data_loader.read_pngs()
Beispiel #10
0
                        default=1e-4,
                        type=float,
                        help="Weight decay, Default: 1e-4")
    args = parser.parse_args()

    # Keep per-scale outputs apart, e.g. outputs/x2, outputs/x3, ...
    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = VDSR().to(device)

    criterion = nn.MSELoss()

    # NOTE(review): clip_grad_norm_ clips the gradients that exist *now*
    # (none yet) — to be effective it must run inside the training loop
    # after backward(); here it is a no-op.  Confirm intent and relocate.
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)
Beispiel #11
0
    def train_vdsr(self, iteration):
        """Train the VDSR graph for `iteration` epochs over X2/X3/X4 data.

        For every mini-batch one optimization step is taken per scale
        (X2, X3, X4) against the shared grayscale HR labels.  The loss is
        mean-squared error plus a 1e-4-weighted L2 weight penalty.  A
        checkpoint is saved to self.save_path after every epoch.
        """
        # images = low resolution, labels = high resolution
        sess = self.sess
        # Load data: one LR file list per scale, plus the shared HR labels.
        # (Replaces three near-identical k==0/1/2 branches in the old code.)
        train_image_lists = [
            sorted(glob.glob('./dataset/training/X{}/*.*'.format(scale)))
            for scale in (2, 3, 4)
        ]
        train_label_list = sorted(glob.glob('./dataset/training/gray/*.*'))

        num_image = len(train_label_list)

        sr_model = VDSR(channel_length=self.c_length, image=self.x)
        prediction, _, l2_loss = sr_model.build_model()

        learning_rate = tf.placeholder(dtype='float32', name='learning_rate')

        with tf.name_scope("mse_loss"):
            loss = tf.reduce_mean(tf.square(self.y - prediction))
            loss += 1e-4 * l2_loss

        train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(loss)
        # optimize = tf.train.AdamOptimizer(learning_rate=learning_rate, momentum=0.9)
        '''
        # gradient clipping = Adam can handle by itself
        gvs = optimize.compute_gradients(loss=loss)
        capped_gvs = [(tf.clip_by_value(grad, -10./learning_rate, 10./learning_rate), var) for grad, var in gvs]
        train_op = optimize.apply_gradients(capped_gvs)
        '''
        batch_size = 3
        # Ceiling division: final partial batch is included.
        num_batch = int((num_image - 1) / batch_size) + 1
        print(num_batch)

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(max_to_keep=2)
        if self.pre_trained:
            saver.restore(sess, self.save_path)

        lr = 1e-3

        for i in range(iteration):
            total_loss = 0  # mse + l2
            total_l2 = 0
            # Decay the learning rate by 10% every 20 epochs.
            if i % 20 == 19:
                lr = lr * 0.9
            for j in range(num_batch):
                for image_list in train_image_lists:
                    batch_image, batch_label = preprocess.load_data(
                        image_list,
                        train_label_list, j * batch_size,
                        min((j + 1) * batch_size, num_image),
                        self.patch_size, self.num_patch_per_image)

                    l2, losses, _ = sess.run(
                        [l2_loss, loss, train_op],
                        feed_dict={
                            self.x: batch_image,
                            self.y: batch_label,
                            learning_rate: lr
                        })
                    # Average over all (batch, scale) steps in this epoch.
                    total_loss += losses / (num_batch * 3)
                    total_l2 += 1e-4 * l2 / (num_batch * 3)

            print('In', '%04d' % (i + 1), 'epoch, current loss is',
                  '{:.5f}'.format(total_loss - total_l2),
                  '{:.5f}'.format(total_l2))
            saver.save(sess, save_path=self.save_path)

        print('Train completed')
Beispiel #12
0
# =======================================================
    # GPU selection: gpu == -1 requests CPU-only execution.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    config = tf.ConfigProto()
    # NOTE(review): ConfigProto.device_count is a protobuf map field; direct
    # assignment like this raises — presumably config.device_count['GPU'] = 0
    # was intended.  Confirm.
    if args.gpu == -1: config.device_count = {'GPU': 0}
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    #config.operation_timeout_in_ms=10000

    g = tf.Graph()
    g.as_default()
    with tf.Session(graph=g, config=config) as sess:
        # -----------------------------------
        # build model
        # -----------------------------------
        model_path = args.checkpoint_dir
        vdsr = VDSR(sess, args=args)

        # -----------------------------------
        # train, test, inference
        # -----------------------------------
        if args.mode == "train":
            vdsr.train()

        elif args.mode == "test":
            vdsr.test()

        elif args.mode == "inference":
            # load the image to upscale (processing continues past this excerpt)
            image_path = os.path.join(os.getcwd(), "test", args.infer_subdir,
                                      args.infer_imgpath)
            infer_image = plt.imread(image_path)
Beispiel #13
0
    # Training setup (fragment): choose SR / LR modules by name, defaulting
    # to FSRCNN / FLRCNN for unrecognized values.
    if not os.path.exists(opts.weights_dir):
        os.mkdir(opts.weights_dir)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    # Fixed seed for reproducible initialization.
    torch.manual_seed(42)

    if opts.sr_module == "FSRCNN":
        sr_module = FSRCNN(scale=opts.scale).to(device)
    elif opts.sr_module == "ESPCN":
        sr_module = ESPCN(scale=opts.scale).to(device)
    elif opts.sr_module == "VDSR":
        sr_module = VDSR(scale=opts.scale).to(device)
    else:
        sr_module = FSRCNN(scale=opts.scale).to(device)

    if opts.lr_module == "FLRCNN":
        lr_module = FLRCNN(scale=opts.scale).to(device)
    elif opts.lr_module == "DESPCN":
        lr_module = DESPCN(scale=opts.scale).to(device)
    elif opts.lr_module == "DVDSR":
        lr_module = DVDSR(scale=opts.scale).to(device)
    else:
        lr_module = FLRCNN(scale=opts.scale).to(device)

    criterion = nn.MSELoss()

    # (Truncated: the Adam parameter-group list continues past this excerpt.)
    optimizer = optim.Adam([{
Beispiel #14
0
        model = FSRCNN(num_channels=3, upscale_factor=4)

    if opt.model == "FALSR_A" or opt.model == "FALSR_B":
        if opt.upscale is not 2:
            raise ("ONLY SUPPORT 2X")
        else:
            if opt.model == "FALSR_A":
                model = FALSR_A()
            if opt.model == "FALSR_B":
                model = FALSR_B()

    if opt.model == "SRCNN" and opt.upscale == 4:
        model = SRCNN(num_channels=3, upscale_factor=4)

    if opt.model == "VDSR" and opt.upscale == 4:
        model = VDSR(num_channels=3, base_channels=3, num_residual=20)

    if opt.model == "ESPCN" and opt.upscale == 4:
        model = ESPCN(num_channels=3, feature=64, upscale_factor=4)

if opt.criterion:
    if opt.criterion == "l1":
        criterion = nn.L1Loss()
    if opt.criterion == "l2":
        criterion = nn.MSELoss()
    if opt.criterion == "custom":
        pass

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()
Beispiel #15
0
                        required=True)
    parser.add_argument("--scale", type=int, default=2)

    opts = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # SR (upscaling) module; unrecognized names fall back to FSRCNN.
    if opts.sr_module == "FSRCNN":
        sr_module = FSRCNN(scale=opts.scale)
    elif opts.sr_module == "ESPCN":
        sr_module = ESPCN(scale=opts.scale)
    elif opts.sr_module == "VDSR":
        sr_module = VDSR(scale=opts.scale)
    else:
        sr_module = FSRCNN(scale=opts.scale)

    # LR (downscaling) module; unrecognized names fall back to FLRCNN.
    if opts.lr_module == "FLRCNN":
        lr_module = FLRCNN(scale=opts.scale)
    elif opts.lr_module == "DESPCN":
        lr_module = DESPCN(scale=opts.scale)
    elif opts.lr_module == "DVDSR":
        lr_module = DVDSR(scale=opts.scale)
    else:
        lr_module = FLRCNN(scale=opts.scale)

    sr_module = sr_module.to(device)
    lr_module = lr_module.to(device)
    # NOTE(review): eval mode is entered without loading any weights in this
    # excerpt — presumably a checkpoint is loaded further down; confirm.
    sr_module.eval()
Beispiel #16
0
print('===> Loading datasets')

train_dataloader = DataLoader(dataset=DatasetFromHdf5(scale=4,train=True, filename = 'celebA_train_s4.h5'), num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
test_dataloader = DataLoader(dataset=DatasetFromHdf5(scale=4,train=False,filename = 'celebA_test2_s4.h5'), num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

print('===> Building model')
# High- and low-frequency sub-networks at the configured pyramid level.
Hmodel = Net(level = opt.level)
Lmodel = Net(level = opt.level)
# Conversion-net input channel count depends on the pyramid level.
if opt.level == 0:
    curL_in=1 #level 0: 1, level 1: 16, level 2: 496
if opt.level == 1:
    curL_in=16
if opt.level == 2:
    curL_in=496
#model_level1 = torch.load(os.path.join('./SR_v2', "model_epoch_200.pth"))
model = VDSR(curL_in=curL_in,filter_num = opt.f_size)#ConvertNet(curL_in=50,receptive_size=4)
#model = model.load_state_dict(os.path.join(path, "model_epoch_200.pth"))
Hmodel.init_weight_h()
Lmodel.init_weight_l()
criterion2 = nn.MSELoss()
# NOTE(review): size_average is deprecated/removed in modern PyTorch —
# reduction='sum' is the equivalent; confirm the torch version in use.
criterion = nn.MSELoss(size_average=False)
    
# `cuda` is defined earlier in the full script (not visible in this excerpt).
if cuda:
    Hmodel = Hmodel.cuda()
    Lmodel = Lmodel.cuda()
    model = model.cuda()
    #model_level1 = model_level1.cuda()
    criterion = criterion.cuda()
    criterion2 = criterion2.cuda()
optimizer = optim.Adam(model.parameters(),lr=opt.lr)    
#Hoptimizer = optim.Adam([{'params': Hmodel.conv1_b.parameters(),'lr':0},{'params': Hmodel.conv2_b.parameters(),'lr': 0},
import torch
import matplotlib.pyplot as plt
from model import VDSR
import cv2
import torchvision.transforms as T
import numpy as np
import math

# Single-image VDSR demo (fragment): load a checkpoint, downscale a bitmap's
# luminance channel, and (past this excerpt) feed it through the network.
device = torch.device('cuda:0')
transform = T.ToTensor()
net = VDSR()
checkpoint = torch.load('D:/VDSR_SGD_epoch_60.pth')
net.load_state_dict(checkpoint['model_state_dict'])
net = net.to(device)
net.eval()
image_path = 'D:/train_data/91/000tt16.bmp'
img = cv2.imread(image_path)
img_r = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Work in YCrCb: only the Y (luminance) channel is processed below.
img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)

# NOTE(review): cv2.resize takes (width, height) — this keeps the full width
# and halves only the height; presumably both dimensions were meant to be
# halved.  Confirm.
img = cv2.resize(img, (img.shape[1], img.shape[0] // 2),
                 interpolation=cv2.INTER_CUBIC)

#img_original=img_r[200:230,300:330]
Y, Cr, Cb = cv2.split(img)

patch = Y[200:230, 300:330]

plt.imshow(img_r)
plt.show()
img = transform(Y)
Beispiel #18
0
import numpy as np
import torch.optim as optim
from data_utils import DatasetFromFolder
from tensorboardX import SummaryWriter
from model import VDSR

# VDSR training script (fragment): SGD with step LR decay and MSE loss.
# The training-loop body is cut off right after zero_grad() in this excerpt.
device=torch.device('cuda:0')
writer=SummaryWriter('D:/VDSR')

transform=T.ToTensor()

trainset=DatasetFromFolder('D:/train_data/291',transform=transform)
trainLoader=DataLoader(trainset,batch_size=128,shuffle=True)


net=VDSR()
net=net.to(device)

# Learning rate decays by 10x every 10 epochs via the StepLR scheduler.
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9,weight_decay=1e-4)
scheduler=optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.1)
criterion=nn.MSELoss()
criterion=criterion.to(device)

net.train()
for epoch in range(20):

    running_cost=0.0
    for i,data in enumerate (trainLoader,0):
        input,target=data
        input,target=input.to(device),target.to(device)
        optimizer.zero_grad()