Example #1
0
                        help='the output folder')
    parser.add_argument('--model_name',
                        type=str,
                        default='model',
                        help='the name of the model')

    args = parser.parse_args()
    args.batch_size_test = args.batch_size if args.batch_size_test == None else args.batch_size_test
    config_visible_gpu(args.gpu)

    train_loader, test_loader = mnist(batch_size=args.batch_size,
                                      batch_size_test=args.batch_size_test)

    model = ConvNet1(input_size=[28, 28], input_channels=1, output_class=10)

    device_ids, model = parse_device_alloc(device_config=None, model=model)

    lr_func = parse_lr(policy=args.lr_policy, epoch_num=args.epoch_num)
    optimizer = parse_optim(policy=args.optim_policy,
                            params=model.parameters())

    setup_config = {kwarg: value for kwarg, value in args._get_kwargs()}
    setup_config['lr_list'] = [lr_func(idx) for idx in range(args.epoch_num)]
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)

    tricks = {}
    if args.snapshots != None:
        tricks['snapshots'] = args.snapshots

    results = train_test(setup_config=setup_config,
Example #2
0
    config_visible_gpu(args.gpu)
    use_gpu = torch.cuda.is_available() and args.gpu not in ['cpu']

    if args.output_folder == None:
        raise ValueError('The output folder cannot be None')
    if args.model_name == None:
        raise ValueError('The name of the model cannot be None')
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)

    netG = MNIST_Generator()
    netD = MNIST_Discriminator()
    netG.weight_init()
    netD.weight_init()

    _, netG = parse_device_alloc(device_config=None, model=netG)
    _, netD = parse_device_alloc(device_config=None, model=netD)

    optimG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    data = mnist(args.batch_size)

    plot_func = lambda output: plot_mnist_data(true_rows=2,
                                               fake_rows=8,
                                               cols=10,
                                               data=data,
                                               netG=netG,
                                               input_dim=128,
                                               use_gpu=use_gpu,
                                               output_file=output)