示例#1
0
        cifar_net = net.Inception()
    elif args.model == 'pyramid':
        cifar_net = net.PyramidNet(args.res_depth, skip=args.skip_depth)
    elif args.model == 'shake_residual':
        cifar_net = net.ShakeShakeResidualNet(args.res_depth, args.res_width)
    else:
        cifar_net = net.VGG()

    if args.optimizer == 'sgd':
        optimizer = optimizers.MomentumSGD(lr=args.lr)
    else:
        optimizer = optimizers.Adam(alpha=args.alpha)
    optimizer.setup(cifar_net)
    if args.weight_decay > 0:
        optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
    cifar_trainer = trainer.CifarTrainer(cifar_net, optimizer, args.iter, args.batch_size, args.gpu, lr_shape=args.lr_shape, lr_decay=lr_decay_iter)
    if args.prefix is None:
        model_prefix = '{}_{}'.format(args.model, args.optimizer)
    else:
        model_prefix = args.prefix

    state = {'best_valid_error': 100, 'best_test_error': 100, 'clock': time.clock()}
    def on_epoch_done(epoch, n, o, loss, acc, valid_loss, valid_acc, test_loss, test_acc, test_time):
        error = 100 * (1 - acc)
        print('epoch {} done'.format(epoch))
        print('train loss: {} error: {}'.format(loss, error))
        if valid_loss is not None:
            valid_error = 100 * (1 - valid_acc)
            print('valid loss: {} error: {}'.format(valid_loss, valid_error))
        else:
            valid_error = None
示例#2
0
文件: train.py 项目: zghzdxs/GUINNESS
    elif args.optimizer == 'momentum':
        print("optimizer: momentum SGD")
        optimizer = optimizers.MomentumSGD(lr=args.lr)
    elif args.optimizer == 'delta':
        print("optimizer: AdaDelta")
        optimizer = optimizers.AdaDelta()
    else:
        print("optimizer: Adam")
        optimizer = optimizers.Adam(alpha=args.alpha)
    optimizer.setup(cifar_net)
    if args.weight_decay > 0:
        optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    optimizer.add_hook(weight_clip.WeightClip())

    cifar_trainer = trainer.CifarTrainer(cifar_net, optimizer, args.iter, args.batch_size, args.gpu)

    state = {'best_valid_error': 100, 'best_test_error': 100, 'clock': time.clock()}
    def on_epoch_done(epoch, n, o, loss, acc, valid_loss, valid_acc, test_loss, test_acc):
        error = 100 * (1 - acc)
        valid_error = 100 * (1 - valid_acc)
        test_error = 100 * (1 - test_acc)
        print('epoch {} done'.format(epoch))
        print('train loss: {} error: {}'.format(loss, error))
        print('valid loss: {} error: {}'.format(valid_loss, valid_error))
        print('test  loss: {} error: {}'.format(test_loss, test_error))
        if valid_error < state['best_valid_error']:
            serializers.save_npz('{}.model'.format(model_prefix), n)
            serializers.save_npz('{}.state'.format(model_prefix), o)
            state['best_valid_error'] = valid_error
            state['best_test_error'] = test_error
示例#3
0
def _plot_learning_curves(log_file_path, model_prefix):
    """Plot train/test loss and error curves from the per-epoch CSV log.

    Reads columns 1, 2, 5 and 6 (train loss, train error, test loss, test
    error) of the log written by ``main`` and saves
    ``model/<model_prefix>_loss.png`` and ``model/<model_prefix>_error.png``.
    Does nothing when the log holds no epoch rows.

    Args:
        log_file_path: Path of the CSV log (one header line, then one row
            per epoch).
        model_prefix: Prefix used for the output image file names.
    """
    # ndmin=2 keeps a single-epoch log two-dimensional; without it the
    # columns come back as 0-d arrays and len() raises TypeError.
    curves = np.loadtxt(log_file_path,
                        delimiter=',',
                        skiprows=1,
                        usecols=(1, 2, 5, 6),
                        ndmin=2)
    if curves.size == 0:
        return
    train_loss, train_error, test_loss, test_error = curves.T
    epoch = len(train_loss)
    xs = np.arange(epoch, dtype=np.int32) + 1

    for name, train_ys, test_ys in (('loss', train_loss, test_loss),
                                    ('error', train_error, test_error)):
        fig, ax = plt.subplots()
        ax.plot(xs, train_ys, label='train {}'.format(name), c='blue')
        ax.plot(xs, test_ys, label='test {}'.format(name), c='red')
        ax.set_xlim((1, epoch))
        ax.set_xlabel('epoch')
        ax.set_ylabel(name)
        ax.legend(loc='upper right')
        save_path = os.path.join('model',
                                 '{}_{}.png'.format(model_prefix, name))
        fig.savefig(save_path, bbox_inches='tight')
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def main():
    """Train a VGG network on CIFAR-10, checkpoint it, and plot its curves.

    Steps:
      * parse command-line arguments and seed NumPy;
      * load CIFAR-10, optionally holding out everything past the first
        45000 training images as a validation set;
      * build ``net.VGG`` from the JSON layer sizes in
        ``args.structure_path``, optionally loading weights from
        ``args.model``;
      * train with MomentumSGD or Adam (plus optional Lasso / weight-decay
        hooks);
      * log each epoch to ``model/<prefix>_log.csv`` and save snapshots and
        plots under ``model/``.
    """
    args = _parse_args()

    np.random.seed(args.seed)
    if args.prefix is None:
        model_prefix = os.path.basename(args.structure_path)
        model_prefix = os.path.splitext(model_prefix)[0]
    else:
        model_prefix = args.prefix
    # Every artifact (log, snapshots, plots) is written under ./model, so
    # create the directory up front instead of failing on the first write.
    if not os.path.isdir('model'):
        os.makedirs('model')
    log_file_path = os.path.join('model', '{}_log.csv'.format(model_prefix))
    # Use a real list: on Python 3 ``map`` is a one-shot iterator, which would
    # silently empty if the trainer reads it more than once (its use of
    # lr_decay is not visible here). Blank entries are ignored.
    lr_decay_epoch = [
        int(e) for e in args.lr_decay_epoch.split(',') if e.strip()
    ]
    # time.clock() was removed in Python 3.8; perf_counter is the portable
    # high-resolution timer (time.time as a fallback on Python 2).
    timer = getattr(time, 'perf_counter', time.time)

    print('loading dataset...')
    train_data, test_data = chainer.datasets.get_cifar10()
    if args.no_valid_data:
        valid_data = None
    else:
        train_data, valid_data = chainer.datasets.split_dataset_random(
            train_data, 45000)
    train_data = chainer.datasets.TransformDataset(train_data, _transform)
    test_data = chainer.datasets.TransformDataset(test_data, _subtract_mean)
    if valid_data is not None:
        valid_data = chainer.datasets.TransformDataset(valid_data,
                                                       _subtract_mean)

    print('start training')
    with open(args.structure_path) as f:
        output_sizes = json.load(f)
    cifar_net = net.VGG(output_sizes)
    if args.model is not None:
        serializers.load_npz(args.model, cifar_net)

    if args.optimizer == 'sgd':
        optimizer = optimizers.MomentumSGD(lr=args.lr)
    else:
        optimizer = optimizers.Adam(alpha=args.alpha)
    optimizer.setup(cifar_net)
    if args.lambda_value > 0:
        # Presumably attaches the Lasso hook only to normalization gamma
        # parameters (helper defined elsewhere) -- TODO confirm.
        _add_hook_to_gamma(cifar_net,
                           chainer.optimizer.Lasso(args.lambda_value))
    if args.weight_decay > 0:
        optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
    cifar_trainer = trainer.CifarTrainer(cifar_net,
                                         optimizer,
                                         args.epoch,
                                         args.batch_size,
                                         args.gpu,
                                         lr_shape=args.lr_shape,
                                         lr_decay=lr_decay_epoch)

    state = {
        'best_valid_error': 100,
        'best_test_error': 100,
        'clock': timer()
    }

    def save_snapshot(n, o, name):
        """Save network ``n`` and optimizer ``o`` as model/<name>.{model,state}."""
        serializers.save_npz(os.path.join('model', '{}.model'.format(name)), n)
        serializers.save_npz(os.path.join('model', '{}.state'.format(name)), o)

    def on_epoch_done(epoch, n, o, loss, acc, valid_loss, valid_acc, test_loss,
                      test_acc, test_time):
        """Report one finished epoch, checkpoint it, and append a log row.

        Accuracies are converted to percentage errors. With validation data
        the model/state are saved only when validation error improves;
        without it they are overwritten every epoch. Every
        ``args.save_epoch`` epochs an extra numbered snapshot is written.
        """
        error = 100 * (1 - acc)
        print('epoch {} done'.format(epoch))
        print('train loss: {} error: {}'.format(loss, error))
        if valid_loss is not None:
            valid_error = 100 * (1 - valid_acc)
            print('valid loss: {} error: {}'.format(valid_loss, valid_error))
        else:
            valid_error = None
        if test_loss is not None:
            test_error = 100 * (1 - test_acc)
            print('test  loss: {} error: {}'.format(test_loss, test_error))
            print('test time: {}s'.format(test_time))
        else:
            test_error = None
        if valid_loss is None:
            # No validation split: always keep the most recent weights.
            save_snapshot(n, o, model_prefix)
            state['best_test_error'] = test_error
        elif valid_error < state['best_valid_error']:
            save_snapshot(n, o, model_prefix)
            state['best_valid_error'] = valid_error
            state['best_test_error'] = test_error
        if args.save_epoch > 0 and (epoch + 1) % args.save_epoch == 0:
            save_snapshot(n, o, '{}_{}'.format(model_prefix, epoch + 1))
        clock = timer()
        print('elapsed time: {}'.format(clock - state['clock']))
        state['clock'] = clock
        with open(log_file_path, 'a') as f:
            f.write('{},{},{},{},{},{},{}\n'.format(epoch, loss, error,
                                                    valid_loss, valid_error,
                                                    test_loss, test_error))

    # The rows contain error percentages, not accuracies; label them so.
    with open(log_file_path, 'w') as f:
        f.write('epoch,train loss,train error,valid loss,valid error,'
                'test loss,test error\n')
    cifar_trainer.fit(train_data, valid_data, test_data, on_epoch_done)

    print('best test error: {}'.format(state['best_test_error']))

    _plot_learning_curves(log_file_path, model_prefix)