Exemplo n.º 1
0
def main(_):
    """Build a VDSR model inside a TensorFlow session and train it.

    The single positional argument is ignored; presumably it is the argv
    list passed in by ``tf.app.run`` (TODO confirm against the caller).
    Every model hyper-parameter is read from the module-level ``FLAGS``,
    and the same ``FLAGS`` object is handed to ``train``.
    """
    with tf.Session() as session:
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            layer=FLAGS.layer,
            c_dim=FLAGS.c_dim,
        )
        model = VDSR(session, **model_kwargs)
        model.train(FLAGS)
Exemplo n.º 2
0
def main(_):
    """Build a VDSR model and either train or test it.

    The positional argument is ignored; presumably it is the argv list
    passed in by ``tf.app.run`` (TODO confirm against the caller).
    ``FLAGS.is_train`` selects the mode: when true the model is trained,
    otherwise ``FLAGS.c_dim`` is set to 3 and the model is tested.
    """
    with tf.Session() as sess:
        vdsr = VDSR(sess,
                    image_size=FLAGS.image_size,
                    label_size=FLAGS.label_size,
                    layer=FLAGS.layer,
                    c_dim=FLAGS.c_dim)
        # Indentation is spaces only: mixing tabs and spaces in this block
        # raises TabError on Python 3.
        if FLAGS.is_train:
            vdsr.train(FLAGS)
        else:
            # NOTE(review): c_dim is forced to 3 only after the model above was
            # built with the original FLAGS.c_dim; confirm that vdsr.test()
            # reads c_dim from FLAGS rather than from the constructed model.
            FLAGS.c_dim = 3
            vdsr.test(FLAGS)
Exemplo n.º 3
0
def main():
    """Create the checkpoint directory if needed, then build and train VDSR.

    All settings are taken from the module-level ``Config`` object, which
    is also passed unchanged to ``train``.
    """
    ckpt_dir = Config.checkpoint_dir
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    with tf.Session() as session:
        model = VDSR(
            session,
            image_size=Config.image_size,
            label_size=Config.label_size,
            batch_size=Config.batch_size,
            c_dim=Config.c_dim,
            checkpoint_dir=ckpt_dir,
            scale=Config.scale,
        )
        model.train(Config)
Exemplo n.º 4
0
    criterion = nn.MSELoss()

    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    for epoch in range(args.num_epochs):
        model.train()
        loss_sum = 0

        if epoch < 20:
            learning_rate = args.lr
        elif epoch >= 20 and epoch < 40:
            learning_rate = 0.001
        elif epoch >= 40 and epoch < 60:
            learning_rate = 1e-2 * args.lr
        else:
            learning_rate = 1e-3 * args.lr

        print("lr = {}".format(learning_rate))

        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
Exemplo n.º 5
0
    #config.operation_timeout_in_ms=10000

    g = tf.Graph()
    g.as_default()
    with tf.Session(graph=g, config=config) as sess:
        # -----------------------------------
        # build model
        # -----------------------------------
        model_path = args.checkpoint_dir
        vdsr = VDSR(sess, args=args)

        # -----------------------------------
        # train, test, inferecnce
        # -----------------------------------
        if args.mode == "train":
            vdsr.train()

        elif args.mode == "test":
            vdsr.test()

        elif args.mode == "inference":
            #load image
            image_path = os.path.join(os.getcwd(), "test", args.infer_subdir,
                                      args.infer_imgpath)
            infer_image = plt.imread(image_path)
            if np.max(infer_image) > 1: infer_image = infer_image / 255
            infer_image = imresize(infer_image,
                                   scalar_scale=1,
                                   output_shape=None,
                                   mode="vec")
Exemplo n.º 6
0
# Train VDSR on the images under D:/train_data/291 with SGD + MSE loss.
# NOTE(review): `device`, `DatasetFromFolder`, `VDSR`, `T`, `nn`, `optim`,
# `DataLoader` and `torch` are expected to be defined/imported earlier in the file.
transform = T.ToTensor()

trainset = DatasetFromFolder('D:/train_data/291', transform=transform)
trainLoader = DataLoader(trainset, batch_size=128, shuffle=True)


net = VDSR()
net = net.to(device)

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# StepLR multiplies the learning rate by gamma=0.1 every 10 scheduler steps.
# NOTE(review): scheduler.step() is not called anywhere in this section; unless
# it is called after the batch loop further down, the learning rate stays 0.01.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.MSELoss()
criterion = criterion.to(device)

net.train()
for epoch in range(20):

    running_cost = 0.0
    for i, data in enumerate(trainLoader, 0):
        input, target = data
        input, target = input.to(device), target.to(device)
        optimizer.zero_grad()
        output = net(input)
        loss = criterion(output, target)
        loss.backward()
        # Clip gradients when training with SGD. Comparing the optimizer object
        # to the string 'SGD' is never true, so clipping previously never ran.
        if isinstance(optimizer, optim.SGD):
            nn.utils.clip_grad_norm_(net.parameters(), 0.4)
        optimizer.step()
        running_cost += loss.item()
        # NOTE(review): this overwrites the checkpoint after every batch; saving
        # once per epoch would be far cheaper if per-batch snapshots are not needed.
        torch.save(net.state_dict(), 'VDSR.pth')