def main():
    """Train the line-drawing colorization GAN.

    Builds a UNET generator (``cnn``), a DIS discriminator (``dis``) and a
    fixed pretrained line network (``l``), then runs a Chainer trainer with
    periodic snapshots and logging.

    NOTE(review): a second ``main`` defined later in this file shadows this
    one at import time; this version is kept with its interface unchanged.
    """
    parser = argparse.ArgumentParser(
        description='chainer line drawing colorization')
    parser.add_argument('--batchsize', '-b', type=int, default=16,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--dataset', '-i', default='./images/',
                        help='Directory of image files.')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    root = args.dataset
    #model = "./model_paint"
    cnn = unet.UNET()
    #serializers.load_npz("result/model_iter_10000", cnn)
    dis = unet.DIS()
    #serializers.load_npz("result/model_dis_iter_20000", dis)
    l = lnet.LNET()
    # Pretrained line network; loaded from disk and never trained here.
    serializers.load_npz("models/liner_f", l)

    dataset = Image2ImageDataset(
        "dat/images_color_train.dat", root + "line/", root + "color/",
        train=True)
    #dataset.set_img_dict(img_dict)
    train_iter = chainer.iterators.SerialIterator(dataset, args.batchsize)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        cnn.to_gpu()  # Copy the model to the GPU
        dis.to_gpu()  # Copy the model to the GPU
        l.to_gpu()

    # Setup optimizer parameters.
    opt = optimizers.Adam(alpha=0.0001)
    opt.setup(cnn)
    opt.add_hook(chainer.optimizer.WeightDecay(1e-5), 'hook_cnn')
    opt_d = chainer.optimizers.Adam(alpha=0.0001)
    opt_d.setup(dis)
    opt_d.add_hook(chainer.optimizer.WeightDecay(1e-5), 'hook_dec')

    # Set up a trainer
    updater = ganUpdater(
        models=(cnn, dis, l),
        iterator={
            'main': train_iter,
            #'test': test_iter
        },
        optimizer={
            'cnn': opt,
            'dis': opt_d},
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    snapshot_interval2 = (args.snapshot_interval * 2, 'iteration')
    trainer.extend(extensions.dump_graph('cnn/loss'))
    trainer.extend(extensions.snapshot(), trigger=snapshot_interval2)
    trainer.extend(extensions.snapshot_object(
        cnn, 'cnn_128_iter_{.updater.iteration}'), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'cnn_128_dis_iter_{.updater.iteration}'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        opt, 'optimizer_'), trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=(10, 'iteration')))
    trainer.extend(extensions.PrintReport([
        'epoch', 'cnn/loss', 'cnn/loss_rec', 'cnn/loss_adv',
        'cnn/loss_tag', 'cnn/loss_l', 'dis/loss']))
    trainer.extend(extensions.ProgressBar(update_interval=20))

    # BUGFIX: resume must be loaded *before* training runs; the original
    # only loaded the snapshot after trainer.run() had already finished,
    # which made --resume a no-op for the training itself.
    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()

    # Save the trained model.
    # BUGFIX: the original referenced an undefined name ``out_dir`` here
    # (NameError at the end of every run); the output dir is args.out.
    chainer.serializers.save_npz(os.path.join(args.out, 'model_final'), cnn)
    chainer.serializers.save_npz(
        os.path.join(args.out, 'optimizer_final'), opt)
def main():
    """Train the colorization GAN with a selectable color mode (YUV or LAB).

    Builds generator/discriminator variants per ``--colormode``, trains with
    a Chainer trainer, supports resuming from a snapshot in ``--out``, and
    saves the final model and optimizer state under ``--out``.
    """
    parser = argparse.ArgumentParser(
        description='chainer line drawing colorization')
    parser.add_argument('--batchsize', '-b', type=int, default=16,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument(
        '--dataset', '-i',
        default='/media/ljw/Research/research/Deep_Learning/data/Places2/',
        help='Directory of image files.')
    # parser.add_argument('--dataset', '-i', default='/home/ljw/deep_learning/intercolorize/data/farm/',
    #                     help='Directory of image files.')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed')
    parser.add_argument('--snapshot_interval', type=int, default=5000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=100,
                        help='Interval of displaying log to console')
    parser.add_argument('--colormode', default='LAB',
                        help='Color mode')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    root = args.dataset
    #model = "./model_paint"
    if args.colormode == 'YUV':
        cnn = unet.UNET()
        dis = unet.DIS()
    elif args.colormode == 'LAB':
        cnn = unet.UNET(inputChannel=3, outputChannel=2)
        dis = unet.DIS(inputChannel=2)
    else:
        # BUGFIX: the original only printed the message and fell through,
        # leaving ``cnn``/``dis`` unbound and crashing later with a
        # confusing NameError. Fail fast with the same message instead.
        raise ValueError('ERROR! Unexpected color mode!!!')

    # l = lnet.LNET()
    # serializers.load_npz("../models/liner_f", l)  # load pre-trained model to l

    # the class of dataset
    dataset = Image2ImageDataset(
        "/media/ljw/Research/research/Deep_Learning/data/Places2/filelist_places365-standard/places365_train_outdoor_color512-all.txt",
        root + "/", root + "data_large", train=True,
        colormode=args.colormode)
    # dataset = Image2ImageDataset(
    #     "/home/ljw/deep_learning/intercolorize/data/farm/color_512.txt",
    #     root + "gray/", root + "color/", train=True, colormode=args.colormode)
    train_iter = chainer.iterators.SerialIterator(dataset, args.batchsize)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        cnn.to_gpu()  # Copy the model to the GPU
        dis.to_gpu()  # Copy the model to the GPU
        # l.to_gpu()

    # Setup optimizer parameters.
    opt = optimizers.Adam(alpha=0.0001)  # use the Adam
    opt.setup(cnn)
    # L2 regularization on the generator weights.
    opt.add_hook(chainer.optimizer.WeightDecay(1e-5), 'hook_cnn')
    opt_d = chainer.optimizers.Adam(alpha=0.0001)
    opt_d.setup(dis)
    opt_d.add_hook(chainer.optimizer.WeightDecay(1e-5), 'hook_dec')

    # Set up a trainer
    updater = ganUpdater(
        colormode=args.colormode,
        models=(cnn, dis),
        iterator={
            'main': train_iter,
            #'test': test_iter
        },
        optimizer={
            'cnn': opt,
            'dis': opt_d
        },
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    snapshot_interval2 = (args.snapshot_interval * 2, 'iteration')
    trainer.extend(extensions.dump_graph('cnn/loss'))
    trainer.extend(extensions.snapshot(), trigger=snapshot_interval2)
    trainer.extend(extensions.snapshot_object(
        cnn, 'cnn_128_iter_{.updater.iteration}'), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'cnn_128_dis_iter_{.updater.iteration}'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(opt, 'optimizer_'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=(10, 'iteration')))
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'cnn/loss', 'cnn/loss_rec', 'cnn/loss_adv',
            'cnn/loss_tag', 'cnn/loss_l', 'dis/loss'
        ]))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.resume:
        # Resume from a snapshot (path is relative to the output dir).
        chainer.serializers.load_npz(
            os.path.join(args.out, args.resume), trainer)

    trainer.run()

    # Save the trained model
    chainer.serializers.save_npz(os.path.join(args.out, 'model_final'), cnn)
    chainer.serializers.save_npz(
        os.path.join(args.out, 'optimizer_final'), opt)