def main():
    """Train a vanilla (fully-connected) autoencoder on flattened MNIST.

    Command-line options:
        --batchsize/-b  minibatch size (default 128)
        --epoch/-e      number of training epochs (default 100)
        --gpu/-g        device id passed to the updater and evaluator (default 0)
        --snapshot/-s   save a model snapshot every N epochs (default 10)
        --n_dimz/-z     size of the latent code (default 64)

    The MNIST training set is split (seed 0) into 50,000 training examples
    and the remaining examples for validation. Logs and snapshots are
    written to ``result/``.
    """
    parser = argparse.ArgumentParser(description="Vanilla_AE")
    parser.add_argument("--batchsize", "-b", type=int, default=128)
    parser.add_argument("--epoch", "-e", type=int, default=100)
    parser.add_argument("--gpu", "-g", type=int, default=0)
    parser.add_argument("--snapshot", "-s", type=int, default=10)
    parser.add_argument("--n_dimz", "-z", type=int, default=64)
    args = parser.parse_args()

    # print settings
    print("GPU:{}".format(args.gpu))
    print("epoch:{}".format(args.epoch))
    print("Minibatch_size:{}".format(args.batchsize))
    print('')

    batchsize = args.batchsize
    max_epoch = args.epoch

    # Unlabeled, flattened images (ndim=1); the official test split is not
    # used by this script.
    train_val, _ = mnist.get_mnist(withlabel=False, ndim=1)
    train, valid = split_dataset_random(train_val, 50000, seed=0)

    # n_out=784 matches the flattened 28x28 MNIST image size.
    model = Network.AE(n_dimz=args.n_dimz, n_out=784)

    # set iterator
    train_iter = iterators.SerialIterator(train, batchsize)
    # Validation is a single ordered pass per evaluation.
    valid_iter = iterators.SerialIterator(valid, batchsize,
                                          repeat=False, shuffle=False)

    # optimizer
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer bound to *model* with L2 weight decay."""
        optimizer = optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001))
        return optimizer

    opt = make_optimizer(model)

    # trainer
    updater = Updater.AEUpdater(model=model, iterator=train_iter,
                                optimizer=opt, device=args.gpu)
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='result')

    # NOTE: to decay the Adam step size, shift its 'alpha' hyperparameter,
    # e.g. extensions.ExponentialShift('alpha', 0.5) with
    # trigger=(30, 'epoch'). Chainer's Adam does not take 'lr' as a settable
    # hyperparameter (assumed read-only/derived -- confirm for your version).
    trainer.extend(extensions.LogReport(log_name='log'))
    trainer.extend(
        Evaluator.AEEvaluator(iterator=valid_iter, target=model,
                              device=args.gpu))
    trainer.extend(
        extensions.snapshot_object(
            model, filename='model_snapshot_epoch_{.updater.epoch}.npz'),
        trigger=(args.snapshot, 'epoch'))
    # NOTE: optimizer state can be snapshotted the same way via
    # extensions.snapshot_object(opt, ...) if resuming training is needed.
    trainer.extend(
        extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss']))
    trainer.extend(extensions.ProgressBar())

    trainer.run()
    del trainer
def main():
    """Train a convolutional or fully-connected autoencoder on 32x32 MNIST.

    Command-line options:
        --batchsize/-b  minibatch size (default 64)
        --epoch/-e      number of training epochs (default 100)
        --gpu/-g        device id passed to the updater (default 0)
        --snapshot/-s   save an AE snapshot every N epochs (default 10)
        --n_dimz/-z     size of the latent code (default 16)
        --dataset/-d    parsed but currently unused; only MNIST is loaded
        --network/-n    'conv' or 'fl'; selects Network.mnist_conv or
                        Network.mnist_fl (default 'conv')

    MNIST images are resized to 32x32. Logs, snapshots and visualizations
    go to ``result/<network>``. A few test images of digits 1 and 5 are
    handed to ``Visualizer.out_generated_image``, which runs every epoch.

    Raises:
        ValueError: if the network name is not recognized. This is
            unreachable via the CLI because argparse ``choices`` rejects it
            first.
    """
    parser = argparse.ArgumentParser(description="Vanilla_AE")
    parser.add_argument("--batchsize", "-b", type=int, default=64)
    parser.add_argument("--epoch", "-e", type=int, default=100)
    parser.add_argument("--gpu", "-g", type=int, default=0)
    parser.add_argument("--snapshot", "-s", type=int, default=10)
    parser.add_argument("--n_dimz", "-z", type=int, default=16)
    parser.add_argument("--dataset", "-d", type=str, default='mnist')
    # Validate the network name at parse time, so a typo fails immediately
    # with a usage message instead of after the dataset has been loaded.
    parser.add_argument("--network", "-n", type=str, default='conv',
                        choices=['conv', 'fl'])
    args = parser.parse_args()

    def transform(in_data):
        """Resize an unlabeled image to 32x32."""
        img = in_data
        img = resize(img, (32, 32))
        return img

    def transform2(in_data):
        """Resize the image of an (image, label) pair to 32x32."""
        img, label = in_data
        img = resize(img, (32, 32))
        return img, label

    # import program
    import Updater
    import Visualizer

    # print settings
    print("GPU:{}".format(args.gpu))
    print("epoch:{}".format(args.epoch))
    print("Minibatch_size:{}".format(args.batchsize))
    print('')

    out = os.path.join('result', args.network)
    batchsize = args.batchsize
    max_epoch = args.epoch

    train_val, _ = mnist.get_mnist(withlabel=False, ndim=3)
    train_val = TransformDataset(train_val, transform)

    # for visualize: first 5 images of label1, and the 6th-10th images of
    # label2 (in test-set order).
    _, test = mnist.get_mnist(withlabel=True, ndim=3)
    test = TransformDataset(test, transform2)
    label1 = 1
    label2 = 5
    n_show = 5
    # One pass that stops early. TransformDataset resizes lazily on each
    # access, so scanning the full test set (twice) would be wasteful.
    test1, test2 = [], []
    for img, label in test:
        if label == label1 and len(test1) < n_show:
            test1.append(img)
        if label == label2 and len(test2) < 2 * n_show:
            test2.append(img)
        if len(test1) >= n_show and len(test2) >= 2 * n_show:
            break
    test1 = test1[0:n_show]
    test2 = test2[n_show:2 * n_show]

    if args.network == 'conv':
        import Network.mnist_conv as Network
    elif args.network == 'fl':
        import Network.mnist_fl as Network
    else:
        raise ValueError('unknown network: {}'.format(args.network))

    AE = Network.AE(n_dimz=args.n_dimz, batchsize=args.batchsize)

    train, valid = split_dataset_random(train_val, 50000, seed=0)

    # set iterator
    train_iter = iterators.SerialIterator(train, batchsize)
    # NOTE(review): valid_iter is built but not attached to any evaluator here.
    valid_iter = iterators.SerialIterator(valid, batchsize,
                                          repeat=False, shuffle=False)

    # optimizer
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer bound to *model* with L2 weight decay."""
        optimizer = optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001))
        return optimizer

    opt_AE = make_optimizer(AE)

    # trainer
    updater = Updater.AEUpdater(model=AE, iterator=train_iter,
                                optimizer={'AE': opt_AE}, device=args.gpu)
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=out)

    trainer.extend(extensions.LogReport(log_name='log'))
    snapshot_interval = (args.snapshot, 'epoch')
    display_interval = (1, 'epoch')
    trainer.extend(
        extensions.snapshot_object(
            AE, filename='AE_snapshot_epoch_{.updater.epoch}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.PrintReport(['epoch', 'AE_loss']),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar())
    trainer.extend(Visualizer.out_generated_image(AE, test1, test2, out),
                   trigger=(1, 'epoch'))

    trainer.run()
    del trainer