def train_phase(predictor, train, valid, args):
    """Fit ``predictor`` (wrapped in ``Classifier``) for ``args.epoch`` epochs.

    Validation metrics come from an ``Evaluator`` extension. Trainer
    snapshots are written every ``args.frequency`` epochs; when it is -1,
    a snapshot is written only at ``args.epoch``. After training, the
    predictor weights are saved to ``<args.out>/predictor.pth``.

    Args:
        predictor: network to train.
        train: training dataset.
        valid: validation dataset.
        args: parsed options; reads ``batchsize``, ``gpu``, ``decay``,
            ``epoch``, ``out``, ``frequency``, ``plot`` and ``resume``.
    """
    # The training stream uses the iterator defaults; the validation stream
    # makes a single, ordered pass each time it is evaluated.
    train_iter = iterators.SerialIterator(train, args.batchsize)
    valid_iter = iterators.SerialIterator(
        valid, args.batchsize, repeat=False, shuffle=False)

    dev = torch.device(args.gpu)
    classifier = Classifier(predictor)
    classifier.to(dev)

    # Negative decay values are clamped to 0, i.e. no weight decay.
    adam = torch.optim.Adam(classifier.parameters(),
                            weight_decay=max(args.decay, 0))

    updater = training.updaters.StandardUpdater(
        train_iter, adam, classifier, device=dev)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, classifier, device=args.gpu))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    # -1 means "snapshot once, at the final epoch".
    if args.frequency == -1:
        snapshot_every = args.epoch
    else:
        snapshot_every = max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(snapshot_every, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        plot_specs = (
            (['main/loss', 'validation/main/loss'], 'loss.png'),
            (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
        )
        for keys, image_name in plot_specs:
            trainer.extend(
                extensions.PlotReport(keys, 'epoch', file_name=image_name))

    report_entries = [
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]
    trainer.extend(extensions.PrintReport(report_entries))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    weights_path = os.path.join(args.out, 'predictor.pth')
    torch.save(predictor.state_dict(), weights_path)
def train_phase(predictor, train, valid, args):
    """Train ``predictor`` as a classifier, stopping early on validation loss.

    Training runs for at most ``args.iteration`` iterations. Every
    ``frequency`` iterations the model is validated (with image dumps under
    ``validation/``), and both the trainer state and the predictor are
    snapshotted. ``frequency`` is ``args.iteration // 20`` when
    ``args.frequency`` is -1.

    Args:
        predictor: network to train; wrapped in ``Classifier``.
        train: training dataset; must expose ``n_classes``.
        valid: validation dataset.
        args: parsed options; reads ``batchsize``, ``lr``, ``decay``,
            ``gpu``, ``iteration``, ``frequency``, ``pinfall``, ``out`` and
            ``resume``.
    """
    print('# classes:', train.n_classes)
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    # NOTE(review): validation is shuffled, presumably so that the n_vis
    # visualized samples differ between checks -- confirm this is intended.
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=True)

    # setup a model
    # Optional per-class weights forwarded to softmax_cross_entropy; set this
    # if the classes are imbalanced.
    class_weight = None

    lossfun = partial(softmax_cross_entropy,
                      normalize=False,
                      class_weight=class_weight)

    device = torch.device(args.gpu)

    model = Classifier(predictor, lossfun=lossfun)
    model.to(device)

    # setup an optimizer (negative decay values are clamped to 0)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=max(args.decay, 0))

    # setup an updater
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)

    # Validation / snapshot interval in iterations: about 20 checks over the
    # whole run when -1, otherwise the user-given value (at least 1).
    if args.frequency == -1:
        frequency = max(args.iteration // 20, 1)
    else:
        frequency = max(1, args.frequency)

    # Stop at args.iteration, or earlier once the monitored validation loss
    # exhausts the EarlyStoppingTrigger patience (args.pinfall; -1 = never).
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    # setup a trainer. A single Trainer is built: the stop trigger already
    # caps the run at args.iteration iterations.
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # setup a visualizer: 'y' holds class scores on axis 0, so argmax turns
    # them into a label map comparable with the target 't'.
    transforms = {
        'x': lambda x: x,
        'y': lambda x: np.argmax(x, axis=0),
        't': lambda x: x
    }

    # label 0 -> black, label 1 -> blue
    cmap = np.array([[0, 0, 0], [0, 0, 1]])
    cmaps = {'x': None, 'y': cmap, 't': cmap}

    clims = {'x': 'minmax', 'y': None, 't': None}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=cmaps,
                                 clims=clims)

    # setup a validator (also dumps visualizations, one file per check)
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter,
                             model,
                             valid_file,
                             visualizer=visualizer,
                             n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    # snapshot the full trainer state and the bare predictor
    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        predictor, 'predictor_iter_{.updater.iteration:08}.pth'),
                   trigger=(frequency, 'iteration'))

    log_keys = [
        'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy'
    ]

    trainer.extend(LogReport(keys=log_keys))

    # setup log plotter: one figure per metric family
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [
                key for key in log_keys
                if key.split('/')[-1].startswith(plot_key)
            ]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=100))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()
# Example #3
# 0
def train_phase(predictor, train, valid, args):
    """Plot the training set, then train ``predictor`` as a noisy regressor.

    A scatter of the observations (``train.y``) over the ground truth
    (``train.t``) is saved to ``<args.out>/train_dataset.png``. The predictor
    is then trained for ``args.epoch`` epochs with
    ``noised_mean_squared_error``. The L1 error is reported as "accuracy",
    and the predictor's ``sigma`` values are printed and plotted. The final
    weights go to ``<args.out>/predictor.pth``.

    Args:
        predictor: network to train; wrapped in ``Regressor``.
        train: training dataset exposing ``x``, ``y`` and ``t`` arrays.
        valid: validation dataset.
        args: parsed options; reads ``out``, ``batchsize``, ``gpu``,
            ``decay``, ``epoch``, ``frequency``, ``plot`` and ``resume``.
    """
    # The figure below is written before the Trainer exists, so the output
    # directory may not have been created yet; savefig does not create it.
    os.makedirs(args.out, exist_ok=True)

    # visualize
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    # Each artist gets its own label: a bare label list passed to legend() is
    # matched to artists in an order that differs between matplotlib versions.
    ax = sns.scatterplot(x=train.x.ravel(),
                         y=train.y.ravel(),
                         color='blue',
                         s=55,
                         alpha=0.3,
                         label='Observation')
    ax.plot(train.x.ravel(),
            train.t.ravel(),
            color='red',
            linewidth=2,
            label='Ground-truth')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    ax.legend()
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    lossfun = noised_mean_squared_error

    def accfun(y, t):
        # Mean absolute error of the first predictor output, reported as
        # "accuracy". NOTE(review): y[0] is presumably the predicted mean.
        return F.l1_loss(y[0], t)

    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer (negative decay values are clamped to 0)
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    # -1 means "snapshot once, at the final epoch".
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch',
                file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
# Example #4
# 0
def parse_device(spec):
    """Convert a ``--device`` / ``--gpu`` value into a ``torch.device``.

    The integer convention of the original Chainer example is kept:
    a negative integer selects the CPU and a non-negative integer ``N``
    selects ``cuda:N``. The integer may be given as an ``int`` or as a
    string of digits. Any other string is handed to ``torch.device``
    unchanged (e.g. ``'cpu'``, ``'cuda:1'``). ``None`` selects the CPU.
    """
    if spec is None:
        return torch.device('cpu')
    if isinstance(spec, torch.device):
        return spec
    if isinstance(spec, str):
        try:
            spec = int(spec)
        except ValueError:
            # Not an integer: let torch parse it ('cpu', 'cuda:1', ...).
            return torch.device(spec)
    # torch.device rejects '-1' and negative ordinals, so map them to the CPU.
    if spec < 0:
        return torch.device('cpu')
    return torch.device('cuda', spec)


def main():
    """Parse CLI options and train an MLP classifier on MNIST.

    The test split is evaluated once before training and again during
    training. Trainer snapshots are written every ``--frequency`` epochs,
    or only at the last epoch when it is -1. Loss/accuracy logs, plots and
    a progress bar are written under ``--out``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency',
                        '-f',
                        type=int,
                        default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device',
                        '-d',
                        type=str,
                        default='-1',
                        help='Device specifier. Either a torch device string '
                        '(e.g. "cpu", "cuda:0") or an integer. If non-negative '
                        'integer, the CUDA device with that id is used. If '
                        'negative integer, the CPU is used')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--autoload',
                        action='store_true',
                        help='Automatically load trainer snapshots in case'
                        ' of preemption or other temporary system failure')
    parser.add_argument('--unit',
                        '-u',
                        type=int,
                        default=1000,
                        help='Number of units')
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu',
                       '-g',
                       dest='device',
                       type=int,
                       nargs='?',
                       const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    # args.device is either the '--device' string (default '-1') or the
    # '--gpu' int; both use the "negative means CPU" convention.
    device = parse_device(args.device)

    print('Device: {}'.format(device))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = Classifier(MLP(784, args.unit, 10))
    model.to(device)

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Load the MNIST dataset
    transform = transforms.ToTensor()
    train = datasets.MNIST('data',
                           train=True,
                           download=True,
                           transform=transform)
    test = datasets.MNIST('data', train=False, transform=transform)

    train_iter = pytorch_trainer.iterators.SerialIterator(
        train, args.batchsize)
    test_iter = pytorch_trainer.iterators.SerialIterator(test,
                                                         args.batchsize,
                                                         repeat=False,
                                                         shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=device),
                   call_before_training=True)

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    # trainer.extend(extensions.DumpGraph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    # Take a snapshot each ``frequency`` epoch, delete old stale
    # snapshots and automatically load from snapshot files if any
    # files are already resident at result directory.
    trainer.extend(extensions.snapshot(n_retains=1, autoload=args.autoload),
                   trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the result dir
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                         'epoch',
                                         file_name='loss.png'),
                   call_before_training=True)
    trainer.extend(extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        'epoch',
        file_name='accuracy.png'),
                   call_before_training=True)

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy', 'elapsed_time'
    ]),
                   call_before_training=True)

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume is not None:
        # Resume from a snapshot (Note: this loaded model is to be
        # overwritten by --autoload option, autoloading snapshots, if
        # any snapshots exist in output directory)
        trainer.load_state_dict(torch.load(args.resume))

    # Run the training
    trainer.run()
def train_phase(generator, train, valid, args):
    """Train a conditional GAN (generator + discriminator) with early stopping.

    The generator is wrapped in a ``Regressor`` (tanh output, L1 loss and L1
    "accuracy"). A discriminator from ``build_discriminator()`` is trained
    alongside it by ``DCGANUpdater``. Both learning rates decay linearly to
    0 over the second half of ``args.iteration``. Every ``frequency``
    iterations the model is validated (with image dumps under
    ``validation/``), and the trainer, generator and discriminator are
    snapshotted.

    Args:
        generator: generator network to train.
        train: training dataset.
        valid: validation dataset.
        args: parsed options; reads ``batchsize``, ``out``, ``gpu``, ``lr``,
            ``beta``, ``decay``, ``alpha``, ``iteration``, ``frequency``,
            ``pinfall`` and ``resume``.
    """
    print('# samples:')
    print('-- train:', len(train))
    print('-- valid:', len(valid))

    # setup dataset iterators
    train_iter = iterators.SerialIterator(train, args.batchsize)
    # NOTE(review): validation is shuffled, presumably so that the n_vis
    # visualized samples differ between checks -- confirm this is intended.
    valid_iter = iterators.SerialIterator(valid, args.batchsize,
                                          repeat=False, shuffle=True)

    # setup a model
    model = Regressor(generator,
                      activation=torch.tanh,
                      lossfun=F.l1_loss,
                      accfun=F.l1_loss)

    discriminator = build_discriminator()
    # The output directory is written to here, before the Trainer exists;
    # make sure it is there first.
    os.makedirs(args.out, exist_ok=True)
    discriminator.save_args(os.path.join(args.out, 'discriminator.json'))

    device = torch.device(args.gpu)

    model.to(device)
    discriminator.to(device)

    # setup optimizers (negative decay values are clamped to 0)
    optimizer_G = torch.optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta, 0.999),
                                   weight_decay=max(args.decay, 0))

    # setup an updater
    updater = DCGANUpdater(
        iterator=train_iter,
        optimizer={
            'gen': optimizer_G,
            'dis': optimizer_D,
        },
        model={
            'gen': model,
            'dis': discriminator,
        },
        alpha=args.alpha,
        device=args.gpu,
    )

    # Validation / snapshot interval in iterations: about 80 checks over the
    # whole run when -1, otherwise the user-given value (at least 1).
    if args.frequency == -1:
        frequency = max(args.iteration // 80, 1)
    else:
        frequency = max(1, args.frequency)

    # Stop at args.iteration, or earlier once the monitored validation loss
    # exhausts the EarlyStoppingTrigger patience (args.pinfall; -1 = never).
    stop_trigger = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        max_trigger=(args.iteration, 'iteration'),
        check_trigger=(frequency, 'iteration'),
        patients=np.inf if args.pinfall == -1 else max(1, args.pinfall))

    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # shift lr: linear decay from args.lr to 0 between the half-way point
    # and the final iteration, for both optimizers
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_G))
    trainer.extend(
        extensions.LinearShift('lr', (args.lr, 0.0),
                               (args.iteration // 2, args.iteration),
                               optimizer=optimizer_D))

    # setup a visualizer: images are shown as-is on a fixed [-1, 1] range,
    # matching the generator's tanh output
    transforms = {'x': lambda x: x, 'y': lambda x: x, 't': lambda x: x}
    clims = {'x': (-1., 1.), 'y': (-1., 1.), 't': (-1., 1.)}

    visualizer = ImageVisualizer(transforms=transforms,
                                 cmaps=None,
                                 clims=clims)

    # setup a validator (also dumps visualizations, one file per check)
    valid_file = os.path.join('validation', 'iter_{.updater.iteration:08}.png')
    trainer.extend(Validator(valid_iter, model, valid_file,
                             visualizer=visualizer, n_vis=20,
                             device=args.gpu),
                   trigger=(frequency, 'iteration'))

    # trainer.extend(DumpGraph('loss_gen', filename='generative_loss.dot'))
    # trainer.extend(DumpGraph('loss_cond', filename='conditional_loss.dot'))
    # trainer.extend(DumpGraph('loss_dis', filename='discriminative_loss.dot'))

    # snapshot the full trainer state and both bare networks
    trainer.extend(
        extensions.snapshot(
            filename='snapshot_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(
        extensions.snapshot_object(
            generator, 'generator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))
    trainer.extend(
        extensions.snapshot_object(
            discriminator, 'discriminator_iter_{.updater.iteration:08}.pth'),
        trigger=(frequency, 'iteration'))

    log_keys = ['loss_gen', 'loss_cond', 'loss_dis',
                'validation/main/accuracy']

    trainer.extend(LogReport(keys=log_keys, trigger=(100, 'iteration')))

    # setup log plotter: one figure per metric family
    if extensions.PlotReport.available():
        for plot_key in ['loss', 'accuracy']:
            plot_keys = [key for key in log_keys
                         if key.split('/')[-1].startswith(plot_key)]
            trainer.extend(
                extensions.PlotReport(plot_keys,
                                      'iteration',
                                      file_name=plot_key + '.png',
                                      trigger=(frequency, 'iteration')))

    trainer.extend(
        PrintReport(['iteration'] + log_keys + ['elapsed_time'], n_step=1))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    # train
    trainer.run()