示例#1
0
def main():
    """Train the detector on Pascal VOC and keep several model snapshots.

    Loads the train/validation split from ``DATASET_ROOT``, wraps both in
    ``YoloDataset`` (augmentation enabled for training only), builds
    ``model_class`` with the weight of class 0 lowered to 0.2, resumes from
    ``RESULT_DIR/model_last.npz`` when that file exists, and then runs a
    Chainer ``Trainer`` for ``N_EPOCHS`` epochs.
    """
    train_x, train_y, val_x, val_y = load_pascal_voc_dataset(DATASET_ROOT)

    # Both splits use the input geometry defined on the model class.
    geometry = {
        'target_size': model_class.img_size,
        'n_grid': model_class.n_grid,
    }
    train_dataset = YoloDataset(train_x, train_y, **geometry, augment=True)
    test_dataset = YoloDataset(val_x, val_y, **geometry, augment=False)

    n_classes = train_dataset.n_classes
    class_weights = [1.0 for _ in range(n_classes)]
    # NOTE(review): class 0 is presumably the "no object" class - confirm.
    class_weights[0] = 0.2
    model = model_class(n_classes=n_classes,
                        n_base_units=6,
                        class_weights=class_weights)

    # Pick up the weights written by the per-epoch snapshot of an earlier run.
    last_checkpoint = RESULT_DIR + '/model_last.npz'
    if os.path.exists(last_checkpoint):
        print('continue from previous result')
        chainer.serializers.load_npz(last_checkpoint, model)

    optimizer = Adam()
    optimizer.setup(model)

    train_iter = SerialIterator(train_dataset, batch_size=BATCH_SIZE)
    test_iter = SerialIterator(test_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False,
                               repeat=False)
    updater = StandardUpdater(train_iter, optimizer, device=DEVICE)
    trainer = Trainer(updater, (N_EPOCHS, 'epoch'), out=RESULT_DIR)

    report_keys = [
        'main/loss',
        'validation/main/loss',
        'main/cl_loss',
        'validation/main/cl_loss',
        'main/cl_acc',
        'validation/main/cl_acc',
        'main/pos_loss',
        'validation/main/pos_loss',
    ]
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.Evaluator(test_iter, model, device=DEVICE))
    trainer.extend(extensions.PrintReport(report_keys))

    # One snapshot file per criterion; every entry owns its own trigger
    # object. The last entry is the unconditional per-epoch checkpoint that
    # the resume logic above reads back.
    checkpoints = (
        ('best_loss.npz',
         triggers.MinValueTrigger('validation/main/loss')),
        ('best_classification.npz',
         triggers.MaxValueTrigger('validation/main/cl_acc')),
        ('best_position.npz',
         triggers.MinValueTrigger('validation/main/pos_loss')),
        ('model_last.npz', (1, 'epoch')),
    )
    for filename, trigger in checkpoints:
        trainer.extend(extensions.snapshot_object(model, filename),
                       trigger=trigger)

    trainer.run()
示例#2
0
# Build the trainer: run `updater` for `max_epoch` epochs and write all
# output files under 'subject_classify_result_1'.
trainer = training.Trainer(updater, (max_epoch, 'epoch'),
                           out='subject_classify_result_1')

# ### Trainer extensions

# In[12]:

# Aggregate the reported loss/accuracy values and write a log entry once per
# epoch.
trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
# Save the whole trainer (including Updater and Optimizer) - currently disabled.
# trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'), trigger=(10, 'epoch'))
# Save only `model.predictor`, and only on epochs where the validation
# accuracy reaches a new maximum.
trainer.extend(
    extensions.snapshot_object(model.predictor,
                               filename='model_epoch-{.updater.epoch}'),
    trigger=triggers.MaxValueTrigger('validation/main/accuracy', (1, 'epoch')))
# Dump the computational graph rooted at 'main/loss'.
trainer.extend(extensions.dump_graph('main/loss'))
# Evaluate on `test_iter` once per epoch.
trainer.extend(extensions.Evaluator(test_iter, model), trigger=(1, 'epoch'))
# Print the selected log columns to the console once per epoch.
trainer.extend(extensions.PrintReport([
    'epoch', 'main/loss', 'main/accuracy', 'validation/main/loss',
    'validation/main/accuracy', 'elapsed_time'
]),
               trigger=(1, 'epoch'))
# Save the training/validation loss curves as an image (loss.png).
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                     x_key='epoch',
                                     file_name='loss.png'),
               trigger=(1, 'epoch'))
def main():
    """Train the detector chosen by ``args.model`` on the ISIC 2018 data.

    Parses the command line, loads the train/test split, builds the model
    from ``ARCHS`` with ImageNet-pretrained weights, and trains it through a
    ``MultiboxTrainChain`` with Adam. The model is evaluated with
    ``DetectionVOCEvaluator``; the trainer and the model are snapshotted each
    time ``validation/main/map`` reaches a new maximum, and once more after
    training has finished.
    """
    args = parse_args()
    res = Resource(args, train=True)

    train, test, train_gt, test_gt = load_train_test(
        train_dir=const.PREPROCESSED_TRAIN_DIR, gt_dir=const.XML_DIR)
    res.log_info(f'Train: {len(train)}, test: {len(test)}')

    architecture = ARCHS[args.model]
    model = architecture(n_fg_class=len(const.LABELS),
                         pretrained_model='imagenet')
    model.use_preset('evaluate')
    train_chain = MultiboxTrainChain(model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    def build_dataset(images, ground_truth):
        # Every split gets its own Transform, configured from the model's
        # coder, input size and mean.
        return TransformDataset(
            ISIC2018Task1Dataset(images, ground_truth),
            Transform(model.coder, model.insize, model.mean))

    train_dataset = build_dataset(train, train_gt)
    train_iter = chainer.iterators.MultithreadIterator(
        train_dataset, args.batchsize, n_threads=args.loaderjob)

    test_dataset = build_dataset(test, test_gt)
    test_iter = chainer.iterators.MultithreadIterator(
        test_dataset,
        args.batchsize,
        shuffle=False,
        repeat=False,
        n_threads=args.loaderjob)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(train_chain)
    # Parameters named 'b' get their gradient scaled by 2; every other
    # parameter gets weight decay instead.
    for param in train_chain.params():
        hook = GradientScaling(2) if param.name == 'b' else WeightDecay(0.0005)
        param.update_rule.add_hook(hook)

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    trainer.extend(
        DetectionVOCEvaluator(test_iter,
                              model,
                              use_07_metric=False,
                              label_names=const.LABELS))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    report_columns = [
        'epoch', 'iteration', 'main/loss', 'main/loss/loc',
        'main/loss/conf', 'validation/main/map'
    ]
    trainer.extend(extensions.PrintReport(report_columns))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Two independent trigger objects on purpose: a best-value trigger keeps
    # its own state, so each snapshot extension needs a private instance.
    best_trainer_trigger = triggers.MaxValueTrigger(key='validation/main/map')
    best_model_trigger = triggers.MaxValueTrigger(key='validation/main/map')
    trainer.extend(extensions.snapshot(filename='snapshot_best.npz'),
                   trigger=best_trainer_trigger)
    trainer.extend(extensions.snapshot_object(model, 'model_best.npz'),
                   trigger=best_model_trigger)

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()

    # Save the final state regardless of how it scored.
    final_outputs = (
        ('snapshot_last.npz', trainer),
        ('model_last.npz', model),
    )
    for filename, target in final_outputs:
        chainer.serializers.save_npz(os.path.join(args.out, filename), target)
示例#4
0
def main():
    """Run a training session described by two JSON configuration files.

    ``args.config_path`` describes the networks, their optimizers, the batch
    size and the number of epochs; ``args.app_config`` may carry a ``train``
    section selecting optional extensions (``dump_graph``,
    ``max_value_trigger`` / ``min_value_trigger``, ``plot_report``,
    ``print_report``). ``args.gpu`` and, when given, ``args.output_dir``
    override the values from the config file.

    Raises:
        ValueError: If ``dataset.get_dataset()`` returns a dict that has no
            ``'train'`` entry.
    """
    args = parse_args()
    with open(args.config_path) as f:
        config = json.load(f)
    with open(args.app_config) as f:
        app_config = json.load(f)
    app_train_config = app_config.get('train', {})

    # Command-line values take precedence over the config file.
    command_config = {'gpu': args.gpu}
    if args.output_dir is not None:
        command_config['output_dir'] = args.output_dir
    config.update(command_config)

    device_id = config['gpu']
    batch_size = config['batch_size']

    network_params = config['network']
    nets = {k: util.create_network(v) for k, v in network_params.items()}
    optimizers = {
        k: util.create_optimizer(v['optimizer'], nets[k])
        for k, v in network_params.items()
    }
    # With exactly one network the updater/evaluator receive the bare
    # objects; otherwise they receive the name -> object dicts.
    if len(optimizers) == 1:
        key, target_optimizer = next(iter(optimizers.items()))
        target = nets[key]
    else:
        target = nets
        target_optimizer = optimizers

    if device_id >= 0:
        chainer.cuda.get_device_from_id(device_id).use()
        for net in nets.values():
            net.to_gpu()

    datasets = dataset.get_dataset()
    iterators = {}
    train_iterator = None
    if isinstance(datasets, dict):
        for name, data in datasets.items():
            if name == 'train':
                train_iterator = chainer.iterators.SerialIterator(
                    data, batch_size)
            else:
                # Every non-train split is iterated once, in order.
                iterators[name] = chainer.iterators.SerialIterator(
                    data, batch_size, repeat=False, shuffle=False)
        if train_iterator is None:
            # Previously this surfaced as an UnboundLocalError further down.
            raise ValueError(
                "dataset.get_dataset() returned no 'train' split")
    else:
        train_iterator = chainer.iterators.SerialIterator(datasets, batch_size)

    updater = TrainingStep(train_iterator,
                           target_optimizer,
                           model.calculate_metrics,
                           device=device_id)
    trainer = Trainer(updater, (config['epoch'], 'epoch'),
                      out=config['output_dir'])
    if hasattr(model, 'make_eval_func'):
        for name, iterator in iterators.items():
            evaluator = extensions.Evaluator(
                iterator,
                target,
                eval_func=model.make_eval_func(target),
                device=device_id)
            trainer.extend(evaluator, name=name)

    dump_graph_node = app_train_config.get('dump_graph', None)
    if dump_graph_node is not None:
        trainer.extend(extensions.dump_graph(dump_graph_node))

    trainer.extend(extensions.snapshot(filename='snapshot.state'),
                   trigger=(1, 'epoch'))
    for k, net in nets.items():
        file_name = 'latest.{}.model'.format(k)
        trainer.extend(extensions.snapshot_object(net, filename=file_name),
                       trigger=(1, 'epoch'))

    # "Best" snapshots. Best-value triggers are stateful, so every snapshot
    # extension gets its OWN trigger instance: a single shared instance fires
    # only for the first extension that polls it in an iteration, which left
    # all but the first network without a best.*.model file.
    max_value_trigger_key = app_train_config.get('max_value_trigger', None)
    min_value_trigger_key = app_train_config.get('min_value_trigger', None)
    if max_value_trigger_key is not None:
        trigger_class = triggers.MaxValueTrigger
        trigger_key = max_value_trigger_key
    elif min_value_trigger_key is not None:
        trigger_class = triggers.MinValueTrigger
        trigger_key = min_value_trigger_key
    else:
        trigger_class = None
        trigger_key = None
    if trigger_class is not None:
        for key, net in nets.items():
            file_name = 'best.{}.model'.format(key)
            trainer.extend(extensions.snapshot_object(net, filename=file_name),
                           trigger=trigger_class(trigger_key))

    trainer.extend(extensions.LogReport())

    # A single optimizer reports under the default observation key; several
    # optimizers each report under '<name>/lr'.
    single_optimizer = len(optimizers) == 1
    for name, opt in optimizers.items():
        if not hasattr(opt, 'lr'):
            continue
        if single_optimizer:
            trainer.extend(extensions.observe_lr(name))
        else:
            trainer.extend(
                extensions.observe_lr(name, '{}/lr'.format(name)))

    if extensions.PlotReport.available():
        plot_targets = app_train_config.get('plot_report', {})
        for name, targets in plot_targets.items():
            file_name = '{}.png'.format(name)
            trainer.extend(
                extensions.PlotReport(targets, 'epoch', file_name=file_name))

    if not args.silent:
        print_targets = app_train_config.get('print_report', [])
        if print_targets:
            trainer.extend(extensions.PrintReport(print_targets))
        trainer.extend(extensions.ProgressBar())

    # NOTE(review): requires a network named 'gen' in the config.
    trainer.extend(generate_image(nets['gen'], 10, 10,
                                  config['output_image_dir']),
                   trigger=(1, 'epoch'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
示例#5
0
def create_extension(trainer, test_iter, model, config, devices=None):
    """Register the trainer extensions listed in *config* and return *trainer*.

    Each ``config`` key selects one extension kind and its value holds that
    extension's settings. Keys that match none of the branches below are
    ignored. ``devices`` is accepted but not used by this function.
    """
    for key, spec in config.items():
        if key == "Evaluator":
            module = get_class(spec['module'])
            evaluator_cls = getattr(module, spec['name'])
            trigger = parse_trigger(spec['trigger'])
            kwargs = parse_dict(spec, 'args', {})
            # 'voc' (also the fallback value) selects the VOC label names.
            if parse_dict(kwargs, 'label_names', 'voc') == 'voc':
                kwargs['label_names'] = voc_bbox_label_names
            trainer.extend(evaluator_cls(test_iter, model, **kwargs),
                           trigger=trigger)

        elif key == "dump_graph":
            factory = getattr(extensions, key)
            trainer.extend(factory(spec['name']))

        elif key == 'snapshot':
            factory = getattr(extensions, key)
            trigger = parse_trigger(spec['trigger'])
            trainer.extend(factory(), trigger=trigger)

        elif key == 'snapshot_object':
            factory = getattr(extensions, key)
            trigger = parse_trigger(spec['trigger'])
            kwargs = parse_dict(spec, 'args', {})
            # method == 'best': snapshot only when the monitored value hits a
            # new maximum, checked at the configured interval.
            if kwargs and kwargs['method'] == 'best':
                trigger = triggers.MaxValueTrigger(kwargs['name'], trigger)
            trainer.extend(factory(model, 'yolov2_{.updater.iteration}'),
                           trigger=trigger)

        elif key == 'LogReport':
            factory = getattr(extensions, key)
            trigger = parse_trigger(spec['trigger'])
            trainer.extend(factory(trigger=trigger))

        elif key == "PrintReport":
            factory = getattr(extensions, key)
            columns = spec['name'].split(' ')
            trigger = parse_trigger(spec['trigger'])
            trainer.extend(factory(columns), trigger=trigger)

        elif key == "ProgressBar":
            factory = getattr(extensions, key)
            trainer.extend(factory(update_interval=spec['update_interval']))

        elif key == 'observe_lr':
            factory = getattr(extensions, key)
            trigger = parse_trigger(spec['trigger'])
            trainer.extend(factory(), trigger=trigger)

        elif key == "PolynomialShift":
            factory = getattr(lr_utils, key)
            # The parsed trigger is not used, but the entry is still read and
            # parsed so a missing or malformed 'trigger' fails as before.
            parse_trigger(spec['trigger'])
            main_iterator = trainer.updater.get_iterator('main')
            len_dataset = len(main_iterator.dataset)
            batchsize = main_iterator.batch_size
            kwargs = parse_dict(spec, 'args', {})
            kwargs.update(len_dataset=len_dataset,
                          batchsize=batchsize,
                          stop_trigger=trainer.stop_trigger)
            trainer.extend(factory(**kwargs))

        elif key == "DarknetLRScheduler":
            factory = getattr(lr_utils, key)
            kwargs = parse_dict(spec, 'args', {})
            kwargs['step_trigger'] = [int(n) for n in kwargs['step_trigger']]
            trainer.extend(factory(**kwargs))

        elif key == "ExponentialShift":
            factory = getattr(extensions, key)
            attr = spec['attr']
            rate = spec['rate']
            unit = spec['name']
            points = [int(n) for n in spec['numbers']]
            trainer.extend(
                factory(attr, rate),
                trigger=triggers.ManualScheduleTrigger(points, unit))

    return trainer
示例#6
0
def main():
    """Train a CNN classifier on MNIST with Chainer.

    Parses the command line, trains ``CNN`` wrapped in ``L.Classifier`` with
    Adam, evaluates on the MNIST test split, snapshots the model whenever the
    validation accuracy improves, and finally writes the trained model to
    ``mnist-cnn.npz``.
    """
    # Read the command-line arguments.
    parser = argparse.ArgumentParser(description='Chainer MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=20,
                        help='Batch size')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Epoch')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID')
    parser.add_argument('--out', '-o', default='result',
                        help='output directory')
    args = parser.parse_args()

    banner = (
        ('GPU: {}', args.gpu),
        ('# Minibatch-size: {}', args.batchsize),
        ('# epoch: {}', args.epoch),
    )
    for template, value in banner:
        print(template.format(value))
    print('')

    # Model: the CNN wrapped in a classifier with softmax cross entropy loss.
    model = L.Classifier(CNN(), lossfun=F.softmax_cross_entropy)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    # Optimizer.
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # MNIST, requested with ndim=3.
    train, test = chainer.datasets.get_mnist(ndim=3)

    # The training iterator uses the defaults; the test iterator makes one
    # ordered pass.
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    # Updater and trainer.
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluation on the test split.
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump the computational graph of the network.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Logging.
    trainer.extend(extensions.LogReport())

    # Progress plots, only when PlotReport is available.
    if extensions.PlotReport.available():
        plots = (
            (['main/loss', 'validation/main/loss'], 'loss.png'),
            (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
        )
        for y_keys, file_name in plots:
            trainer.extend(
                extensions.PlotReport(y_keys, 'epoch', file_name=file_name))

    # Console report.
    console_columns = [
        'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy', 'elapsed_time'
    ]
    trainer.extend(extensions.PrintReport(console_columns))

    # Snapshot the model each time the validation accuracy reaches a new
    # maximum (checked once per epoch).
    best_accuracy = triggers.MaxValueTrigger('validation/main/accuracy',
                                             trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model, filename='mnist-cnn-best'),
                   trigger=best_accuracy)

    # Progress bar.
    trainer.extend(extensions.ProgressBar())

    # Train.
    trainer.run()

    # Save the final model.
    serializers.save_npz('mnist-cnn.npz', model)
def train(outdir, args):
    """Train a ``SpanSelectionParser`` and snapshot the best-scoring model.

    Builds the train/dev/test iterators for the dataset named by
    ``args.dataset``, dumps the freshly built predictor to
    ``<outdir>/model.setting`` with ``dill``, wraps it in an
    ``FscoreClassifier`` and runs a Chainer ``Trainer`` for ``args.epoch``
    epochs. The dev split is evaluated as ``validation`` and the test split
    as ``test``.

    Args:
        outdir: Directory that receives ``model.setting``, the trainer output
            and the ``model.best`` snapshot.
        args: Settings object. Read directly here: ``vocab_path``, ``epoch``,
            ``lr``, ``dataset``, ``iteration`` and ``fold`` (MT only),
            ``baseline_heuristic``, ``use_elmo``, ``decoder``,
            ``ac_type_alpha``, ``link_type_alpha``, ``gpu_id``,
            ``optimizer``, ``monitor`` and ``log_output``. The whole object is
            also handed to several helpers.

    Raises:
        ValueError: If ``args.dataset`` is neither ``"PE"`` nor ``"MT"``.
        NotImplementedError: If ``args.optimizer`` is not ``"Adam"``.
    """

    #######################
    # initialize settings #
    #######################
    vocab = read_vocabfile(args.vocab_path)
    n_epochs = args.epoch
    lr = args.lr

    #################
    # get iterators #
    #################
    if args.dataset == "PE":
        essay_info_dict, essay_max_n_dict, para_info_dict = \
            get_data_dicts(vocab, args)
        train_ids, dev_ids, test_ids = \
            return_train_dev_test_ids_PE(vocab,
                                         essay_info_dict,
                                         essay_max_n_dict,
                                         para_info_dict,
                                         args,
                                         dev_shuffle=True)

        train_iter, dev_iter, test_iter = \
            return_train_dev_test_iter_PE(train_ids,
                                          dev_ids,
                                          test_ids,
                                          essay_info_dict,
                                          essay_max_n_dict,
                                          args)

    elif args.dataset == "MT":
        train_iter, dev_iter, test_iter, essay_info_dict, \
            essay_max_n_dict, para_info_dict = \
            return_train_dev_test_iter_MT(
                                         vocab,
                                         args,
                                         args.iteration,
                                         args.fold)

    else:
        # Fail here with a clear message instead of a NameError a few lines
        # below, where the dataset dictionaries are first used.
        raise ValueError(
            "unsupported dataset: {!r} (expected 'PE' or 'MT')".format(
                args.dataset))

    max_n_spans_para = essay_max_n_dict["max_n_spans_para"]
    max_n_paras = essay_max_n_dict["max_n_paras"]
    max_n_tokens = essay_max_n_dict["max_n_tokens"]

    ################
    # select model #
    ################
    predictor = SpanSelectionParser(vocab=vocab,
                                    essay_info_dict=essay_info_dict,
                                    para_info_dict=para_info_dict,
                                    max_n_spans_para=max_n_spans_para,
                                    max_n_paras=max_n_paras,
                                    max_n_tokens=max_n_tokens,
                                    settings=args,
                                    baseline_heuristic=args.baseline_heuristic,
                                    use_elmo=args.use_elmo,
                                    decoder=args.decoder)

    sys.stderr.write("dump setting file...\n")
    predictor.to_cpu()

    # Use a context manager so the file is flushed and closed before training
    # starts; the handle used to be left open.
    with open(outdir + "/model.setting", "wb") as setting_file:
        dill.dump(predictor, setting_file)

    model = FscoreClassifier(predictor,
                             max_n_spans_para,
                             args,
                             lossfun=softmax_cross_entropy_flatten,
                             accfun=accuracy_flatten,
                             fscore_target_fun=classification_summary_flatten,
                             fscore_link_fun=fscore_binary,
                             count_prediction=count_prediction,
                             ac_type_alpha=args.ac_type_alpha,
                             link_type_alpha=args.link_type_alpha)

    if args.gpu_id >= 0:
        # NOTE(review): to_gpu() is called without a device id, so the model
        # goes to the current device; confirm that args.gpu_id is selected
        # elsewhere before this point.
        model.to_gpu()

    #####################
    # loading embedding #
    #####################
    load_embeddings(model, vocab, args)

    #############
    # optimizer #
    #############
    if args.optimizer == "Adam":
        optimizer = optimizers.Adam(alpha=lr)
    else:
        raise NotImplementedError

    optimizer.setup(model)

    ##############################
    # training iteration setting #
    ##############################
    # Every consumer works on its own deep copy of the iterators.
    updater = training.StandardUpdater(copy.deepcopy(train_iter),
                                       optimizer,
                                       device=args.gpu_id,
                                       converter=convert)

    trainer = training.Trainer(updater, (n_epochs, 'epoch'), out=outdir)

    ##############
    # extensions #
    ##############
    trainer.extend(extensions.Evaluator(copy.deepcopy(dev_iter),
                                        model,
                                        converter=convert,
                                        device=args.gpu_id),
                   name='validation')
    trainer.extend(extensions.Evaluator(copy.deepcopy(test_iter),
                                        model,
                                        converter=convert,
                                        device=args.gpu_id),
                   name='test')
    trainer.extend(extensions.ProgressBar(update_interval=1))
    trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/macro_f_link',
            'validation/main/macro_f_link_type', 'validation/main/macro_f_type'
        ]))

    # Each extension below gets its own MaxValueTrigger on args.monitor:
    # best-value triggers are stateful and must not be shared.
    trainer.extend(
        log_current_score(trigger=triggers.MaxValueTrigger(key=args.monitor,
                                                           trigger=(1,
                                                                    'epoch')),
                          log_report='LogReport',
                          out_dir=outdir))

    if args.log_output:
        trainer.extend(
            log_current_output(trigger=triggers.MaxValueTrigger(
                key=args.monitor, trigger=(1, 'epoch')),
                               test_iter=copy.deepcopy(test_iter).next(),
                               max_n_spans_para=max_n_spans_para,
                               log_report='LogReport',
                               out=outdir,
                               settings=args))

    trainer.extend(extensions.snapshot_object(model, 'model.best'),
                   trigger=chainer.training.triggers.MaxValueTrigger(
                       key=args.monitor, trigger=(1, 'epoch')))

    trainer.run()
示例#8
0
def train(model_class,
          n_base_units,
          trained_model,
          no_obj_weight,
          data,
          result_dir,
          initial_batch_size=10,
          max_batch_size=1000,
          max_epoch=100):
    """Train ``model_class`` on *data* and return the best-loss model.

    If ``<result_dir>/best_loss.npz`` already exists and loads cleanly, the
    training is treated as done and the loaded model is returned at once.
    Otherwise the model is trained with Adam under early stopping on
    ``validation/main/loss`` and the ``best_loss.npz`` snapshot is loaded
    back before returning.

    Args:
        model_class: Model class; must expose ``img_size`` and ``n_grid`` and
            be callable as ``model_class(n_classes, n_base_units)``.
        n_base_units: Forwarded to the ``model_class`` constructor.
        trained_model: Optional model whose parameters are copied into the
            new model through ``copy_params``; skipped when falsy.
        no_obj_weight: Weight assigned to class 0.
        data: Tuple ``(train_x, train_y, val_x, val_y)``.
        result_dir: Trainer output directory, also searched for an earlier
            ``best_loss.npz``.
        initial_batch_size: Batch size of both iterators at start.
        max_batch_size: Upper bound given to ``AdaptiveBatchsizeIncrement``.
        max_epoch: Hard epoch limit of the early-stopping trigger.

    Returns:
        The model holding the parameters from ``best_loss.npz``.
    """
    train_x, train_y, val_x, val_y = data

    # The class count comes from the largest class id (element 4 of each
    # object) found in the validation labels.
    # NOTE(review): train_y is not scanned - confirm that every class id also
    # appears in val_y.
    max_class_id = 0
    for objs in val_y:
        for obj in objs:
            max_class_id = max(max_class_id, obj[4])
    n_classes = max_class_id + 1

    class_weights = [1.0 for _ in range(n_classes)]
    class_weights[0] = no_obj_weight
    train_dataset = YoloDataset(train_x,
                                train_y,
                                target_size=model_class.img_size,
                                n_grid=model_class.n_grid,
                                augment=True,
                                class_weights=class_weights)
    test_dataset = YoloDataset(val_x,
                               val_y,
                               target_size=model_class.img_size,
                               n_grid=model_class.n_grid,
                               augment=False,
                               class_weights=class_weights)

    model = model_class(n_classes, n_base_units)
    model.loss_calc = LossCalculator(n_classes, class_weights=class_weights)

    last_result_file = os.path.join(result_dir, 'best_loss.npz')
    if os.path.exists(last_result_file):
        try:
            chainer.serializers.load_npz(last_result_file, model)
        except Exception as err:
            # Best effort: an unreadable or incompatible file means we train
            # again. Unlike the former bare ``except``, this no longer
            # swallows KeyboardInterrupt/SystemExit, and the reason is shown.
            print('could not reuse {}: {!r}'.format(last_result_file, err))
        else:
            print('this training has done. resuse the result')
            return model

    if trained_model:
        print('copy params from trained model')
        copy_params(trained_model, model)

    optimizer = Adam()
    optimizer.setup(model)

    # Use half of the reported CPU count (rounded up) for the loader workers.
    n_physical_cpu = int(math.ceil(multiprocessing.cpu_count() / 2))

    train_iter = MultiprocessIterator(train_dataset,
                                      batch_size=initial_batch_size,
                                      n_prefetch=n_physical_cpu,
                                      n_processes=n_physical_cpu)
    test_iter = MultiprocessIterator(test_dataset,
                                     batch_size=initial_batch_size,
                                     shuffle=False,
                                     repeat=False,
                                     n_prefetch=n_physical_cpu,
                                     n_processes=n_physical_cpu)
    # NOTE(review): the device is hard-coded to 0 here and in the Evaluator.
    updater = StandardUpdater(train_iter, optimizer, device=0)
    stopper = triggers.EarlyStoppingTrigger(check_trigger=(1, 'epoch'),
                                            monitor="validation/main/loss",
                                            patients=10,
                                            mode="min",
                                            max_trigger=(max_epoch, "epoch"))
    trainer = Trainer(updater, stopper, out=result_dir)

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.Evaluator(test_iter, model, device=0))
    trainer.extend(
        extensions.PrintReport([
            'epoch',
            'main/loss',
            'validation/main/loss',
            'main/cl_loss',
            'validation/main/cl_loss',
            'main/cl_acc',
            'validation/main/cl_acc',
            'main/pos_loss',
            'validation/main/pos_loss',
        ]))
    # One snapshot per criterion, each with its own trigger instance.
    trainer.extend(extensions.snapshot_object(model, 'best_loss.npz'),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(extensions.snapshot_object(model,
                                              'best_classification.npz'),
                   trigger=triggers.MaxValueTrigger('validation/main/cl_acc'))
    trainer.extend(
        extensions.snapshot_object(model, 'best_position.npz'),
        trigger=triggers.MinValueTrigger('validation/main/pos_loss'))
    trainer.extend(extensions.snapshot_object(model, 'model_last.npz'),
                   trigger=(1, 'epoch'))
    trainer.extend(AdaptiveBatchsizeIncrement(maxsize=max_batch_size),
                   trigger=(1, 'epoch'))

    trainer.run()

    # Return the parameters that achieved the lowest validation loss rather
    # than those of the final epoch.
    chainer.serializers.load_npz(os.path.join(result_dir, 'best_loss.npz'),
                                 model)
    return model
示例#9
0
def main(args=None):
    """Train ``TrainChain`` with a Chainer ``Trainer`` and search thresholds.

    Parses the command-line options, builds the dataset, backbone, optimizer
    and iterators, runs the training loop and finally calls
    ``find_threshold`` on the ``bestmodel_loss`` and ``bestmodel_f2``
    snapshots.  With ``--find-threshold`` the training is skipped and only
    the threshold search is executed on snapshots already stored in
    ``--out``.

    Args:
        args: Optional list of command-line argument strings.  When ``None``
            the arguments are taken from ``sys.argv`` by ``argparse``.

    Raises:
        ValueError: If ``--mixup`` is combined with a loss function other
            than ``focal``.
    """
    set_random_seed(63)
    chainer.global_config.autotune = True
    chainer.cuda.set_max_workspace_size(512 * 1024 * 1024)
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--learnrate',
                        '-l',
                        type=float,
                        default=0.01,
                        help='Learning rate for SGD')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=80,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--loss-function',
                        choices=['focal', 'sigmoid'],
                        default='focal')
    parser.add_argument('--optimizer',
                        choices=['sgd', 'adam', 'adabound'],
                        default='adam')
    parser.add_argument('--size', type=int, default=224)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--lr-search', action='store_true')
    parser.add_argument('--pretrained', type=str, default='')
    parser.add_argument('--backbone',
                        choices=['resnet', 'seresnet', 'debug_model'],
                        default='resnet')
    # NOTE(review): ``args.dropout`` is handed to the backbone constructor
    # below, but no option used to define it, so every run failed with
    # AttributeError.  The type the backbones expect is not visible here; a
    # float that defaults to 0.0 is assumed -- confirm against
    # ``backbone_catalog``.
    parser.add_argument('--dropout',
                        type=float,
                        default=0.0,
                        help='Dropout setting passed to the backbone')
    parser.add_argument('--log-interval', type=int, default=100)
    parser.add_argument('--find-threshold', action='store_true')
    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    args = parser.parse_args() if args is None else parser.parse_args(args)

    print(args)

    if args.mixup and args.loss_function != 'focal':
        raise ValueError('mixupを使うときはfocal lossしか使えません(いまんところ)')

    train, test, cooccurrence = get_dataset(args.data_dir, args.size,
                                            args.limit, args.mixup)
    base_model = backbone_catalog[args.backbone](args.dropout)

    if args.pretrained:
        print('loading pretrained model: {}'.format(args.pretrained))
        chainer.serializers.load_npz(args.pretrained, base_model, strict=False)
    model = TrainChain(base_model,
                       1,
                       loss_fn=args.loss_function,
                       cooccurrence=cooccurrence,
                       co_coef=0)
    if args.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    if args.optimizer in ['adam', 'adabound']:
        optimizer = Adam(alpha=args.learnrate,
                         adabound=args.optimizer == 'adabound',
                         weight_decay_rate=1e-5,
                         gamma=5e-7)
    elif args.optimizer == 'sgd':
        optimizer = chainer.optimizers.MomentumSGD(lr=args.learnrate)

    optimizer.setup(model)

    if not args.finetune:
        # The feature extractor is frozen at the start; an extension
        # registered further down unfreezes it on the epoch trigger.
        print('最初のエポックは特徴抽出層をfreezeします')
        model.freeze_extractor()

    train_iter = chainer.iterators.MultiprocessIterator(train,
                                                        args.batchsize,
                                                        n_processes=8,
                                                        n_prefetch=2)
    test_iter = chainer.iterators.MultithreadIterator(test,
                                                      args.batchsize,
                                                      n_threads=8,
                                                      repeat=False,
                                                      shuffle=False)

    def search_thresholds():
        """Run ``find_threshold`` on the best-loss and best-f2 snapshots."""
        candidates = (
            ('bestmodel_loss', 'lossがもっとも小さかったモデルに対しての結果:'),
            ('bestmodel_f2', 'f2がもっとも大きかったモデルに対しての結果:'),
        )
        for snapshot_name, headline in candidates:
            chainer.serializers.load_npz(join(args.out, snapshot_name),
                                         base_model)
            print(headline)
            find_threshold(base_model, test_iter, args.gpu, args.out)

    if args.find_threshold:
        # Some of the setup above (train_iter, optimizer, ...) is wasted on
        # this path, but it keeps the code simple.
        print('thresholdを探索して終了します')
        search_thresholds()
        return

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter,
        optimizer,
        device=args.gpu,
        converter=lambda batch, device: chainer.dataset.concat_examples(
            batch, device=device))
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(FScoreEvaluator(test_iter, model, device=args.gpu))

    if args.optimizer == 'sgd':
        # Weight decay is reportedly not a good fit for Adam, so the hook is
        # only added for SGD.
        optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(5e-4))
        trainer.extend(extensions.ExponentialShift('lr', 0.1),
                       trigger=(3, 'epoch'))
        if args.lr_search:
            print('最適な学習率を探します')
            trainer.extend(LRFinder(1e-7, 1, 5, optimizer),
                           trigger=(1, 'iteration'))
    elif args.optimizer in ['adam', 'adabound']:
        if args.lr_search:
            print('最適な学習率を探します')
            trainer.extend(LRFinder(1e-7, 1, 5, optimizer, lr_key='alpha'),
                           trigger=(1, 'iteration'))

        # Shrink alpha whenever the early-stopping trigger on the validation
        # loss fires.
        trainer.extend(extensions.ExponentialShift('alpha', 0.2),
                       trigger=triggers.EarlyStoppingTrigger(
                           monitor='validation/main/loss'))

    # Take a snapshot of the Trainer every 10 epochs
    trainer.extend(
        extensions.snapshot(filename='snaphot_epoch_{.updater.epoch}'),
        trigger=(10, 'epoch'))

    # Take a snapshot of Model which has best val loss.
    # Because searching best threshold for each evaluation takes too much time.
    trainer.extend(extensions.snapshot_object(model.model, 'bestmodel_loss'),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(extensions.snapshot_object(model.model, 'bestmodel_f2'),
                   trigger=triggers.MaxValueTrigger('validation/main/f2'))
    trainer.extend(extensions.snapshot_object(model.model,
                                              'model_{.updater.epoch}'),
                   trigger=(5, 'epoch'))

    # Write a log of evaluation statistics every ``log_interval`` iterations
    trainer.extend(
        extensions.LogReport(trigger=(args.log_interval, 'iteration')))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'lr', 'elapsed_time', 'main/loss', 'main/co_loss',
            'validation/main/loss', 'validation/main/co_loss',
            'validation/main/precision', 'validation/main/recall',
            'validation/main/f2', 'validation/main/threshold'
        ]))

    trainer.extend(extensions.ProgressBar(update_interval=args.log_interval))
    trainer.extend(extensions.observe_lr(),
                   trigger=(args.log_interval, 'iteration'))
    trainer.extend(CommandsExtension())
    save_args(args, args.out)

    trainer.extend(lambda trainer: model.unfreeze_extractor(),
                   trigger=(1, 'epoch'))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Save args with pickle for prediction time.  The context manager makes
    # sure the file is flushed and closed before training starts.
    with open(str(Path(args.out).joinpath('args.pkl')), 'wb') as args_file:
        pickle.dump(args, args_file)

    # Run the training
    trainer.run()

    # Find the optimal threshold for the best snapshots
    search_thresholds()
示例#10
0
def main():
    """Train a ``CCNet`` classifier with a Chainer ``Trainer``.

    Parses the command-line options, builds the training and validation
    datasets, wraps ``CCNet`` in ``Classifier`` with softmax cross entropy,
    and trains it with SGD.  Results are written to a directory named
    ``<outdir>_<timestamp>`` where the timestamp is taken in the
    ``Asia/Tokyo`` timezone.
    """
    ap = ArgumentParser(description='python train_cc.py')
    ap.add_argument('--indir', '-i', nargs='?', default='datasets/train', help='Specify input files directory for learning data')
    ap.add_argument('--outdir', '-o', nargs='?', default='results/results_training_cc', help='Specify output files directory for create save model files')
    ap.add_argument('--train_list', nargs='?', default='datasets/split_list/train.list', help='Specify split train list')
    ap.add_argument('--validation_list', nargs='?', default='datasets/split_list/validation.list', help='Specify split validation list')
    ap.add_argument('--init_model', help='Specify Loading File Path of Learned Cell Classification Model')
    ap.add_argument('--gpu', '-g', type=int, default=-1, help='Specify GPU ID (negative value indicates CPU)')
    ap.add_argument('--epoch', '-e', type=int, default=10, help='Specify number of sweeps over the dataset to train')
    ap.add_argument('--batchsize', '-b', type=int, default=5, help='Specify Batchsize')
    ap.add_argument('--crop_size', nargs='?', default='(640, 640)', help='Specify crop size (default (y,x) = (640,640))')
    # The help text previously advertised (1840,700), which did not match the
    # actual default below.
    ap.add_argument('--coordinate', nargs='?', default='(780, 1480)', help='Specify initial coordinate (default (y,x) = (780,1480))')
    ap.add_argument('--nclass', type=int, default=10, help='Specify classification class')

    args = ap.parse_args()

    print('init dataset...')
    # NOTE(review): crop_size and coordinate are forwarded as the raw option
    # strings (e.g. '(640, 640)'); presumably the dataset parses them --
    # confirm against PreprocessedClassificationDataset.
    train_dataset = PreprocessedClassificationDataset(
        path=args.indir,
        split_list=args.train_list,
        crop_size=args.crop_size,
        coordinate=args.coordinate,
        train=True
    )
    validation_dataset = PreprocessedClassificationDataset(
        path=args.indir,
        split_list=args.validation_list,
        crop_size=args.crop_size,
        coordinate=args.coordinate,
        train=False
    )

    print('init model construction')
    model = Classifier(
        CCNet(
            n_class=args.nclass
            ), lossfun=F.softmax_cross_entropy
        )

    if args.init_model is not None:
        print('Load model from', args.init_model)
        chainer.serializers.load_npz(args.init_model, model)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    print('init optimizer...')
    optimizer = chainer.optimizers.SGD(lr=0.01)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(rate=0.0001))

    # Updater
    print('init updater')
    train_iter = chainer.iterators.SerialIterator(
        train_dataset, batch_size=args.batchsize)
    # Validation runs one sample at a time, once, in dataset order.
    validation_iter = chainer.iterators.SerialIterator(
        validation_dataset, batch_size=1, repeat=False, shuffle=False)
    updater = chainer.training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)

    # Trainer
    current_datetime = datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y%m%d_%H%M%S')
    save_dir = args.outdir + '_' + str(current_datetime)
    os.makedirs(save_dir, exist_ok=True)
    trainer = training.Trainer(updater, stop_trigger=(args.epoch, 'epoch'), out=save_dir)

    # Extensions registered below:
    #   Evaluator        : evaluate the model on the validation dataset
    #                      every epoch
    #   snapshot_object  : save the model whenever the validation accuracy
    #                      reaches a new maximum
    #   observe_lr       : record the learning rate every epoch
    #   LogReport        : output the accumulated results to a log file
    #   PrintReport      : print the accumulated results
    #   PlotReport       : output loss and accuracy plots
    #   dump_graph       : dump the computational graph of 'main/loss'

    evaluator = extensions.Evaluator(validation_iter, model, device=args.gpu)
    trainer.extend(evaluator, trigger=(1, 'epoch'))

    trigger = triggers.MaxValueTrigger('validation/main/accuracy', trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model, filename='best_acc_model'), trigger=trigger)

    trainer.extend(chainer.training.extensions.observe_lr(), trigger=(1, 'epoch'))

    # LogReport
    trainer.extend(
        extension=extensions.LogReport()
    )

    # PrintReport
    trainer.extend(
        extension=extensions.PrintReport([
            'epoch', 'iteration',
            'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'elapsed_time'])
    )

    # PlotReport
    trainer.extend(
        extension=extensions.PlotReport(
            ['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
    trainer.extend(
        extension=extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.run()
示例#11
0
def main():
    """Train a Shake-Shake ResNet on CIFAR-10 or CIFAR-100 with Chainer.

    Parses the command-line options, loads the selected CIFAR dataset,
    normalises it with per-channel statistics of the training split, and
    trains ``ShakeResNet`` wrapped in ``L.Classifier`` with Momentum SGD.
    Logs, plots, snapshots and TensorBoard events are written to
    ``results/<out>``.

    Raises:
        RuntimeError: If ``--dataset`` is neither ``cifar10`` nor
            ``cifar100``.
    """
    parser = argparse.ArgumentParser(
        description='Shake-shake resularization CIFAR10 w/ Chainer')
    parser.add_argument('--dataset',
                        '-d',
                        default='cifar10',
                        help='The dataset to use: cifar10 or cifar100')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=128,
                        help='Number of images in each mini-batch')
    parser.add_argument('--lr',
                        '-l',
                        type=float,
                        default=0.1,
                        help='Learning rate for SGD')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1800,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--base_width',
                        '-w',
                        type=int,
                        default=64,
                        help='Base width parameter for Shake-Shake model')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='run_0',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--nobar',
                        dest='bar',
                        action='store_false',
                        help='Disable ProgressBar extension')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    log_dir = os.path.join("results", args.out)
    writer = SummaryWriter(log_dir=log_dir)

    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    if args.dataset == 'cifar10':
        print('Using CIFAR10 dataset.')
        class_labels = 10
        train, test = cifar.get_cifar10(scale=255.)
    elif args.dataset == 'cifar100':
        print('Using CIFAR100 dataset.')
        class_labels = 100
        train, test = cifar.get_cifar100(scale=255.)
    else:
        raise RuntimeError('Invalid dataset choice.')

    # Data preprocess: stack the training images once and reduce over every
    # axis except axis 1, instead of materialising the whole set twice.
    train_images = np.asarray([x for x, _ in train])
    mean = np.mean(train_images, axis=(0, 2, 3))
    std = np.std(train_images, axis=(0, 2, 3))
    del train_images  # release the stacked copy before training starts

    train_transfrom = partial(transform, mean=mean, std=std, train=True)
    test_transfrom = partial(transform, mean=mean, std=std, train=False)

    train = TransformDataset(train, train_transfrom)
    test = TransformDataset(test, test_transfrom)

    print('Finised data preparation. Preparing for model training...')

    # Model and optimizer configuration
    model = L.Classifier(ShakeResNet(class_labels, base_width=args.base_width))

    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    optimizer = chainer.optimizers.MomentumSGD(args.lr, momentum=0.9)
    optimizer.setup(model)
    # Use the optimizer_hooks namespace like the other training scripts in
    # this file; chainer.optimizer.WeightDecay is only a legacy alias.
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(1e-4))

    # Set up a trainer
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 args.batchsize,
                                                 repeat=False,
                                                 shuffle=False)

    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=log_dir)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Decrease learning rate with cosine annealing
    trainer.extend(LrSceduler_CosineAnneal(args.lr, args.epoch))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot at each epoch
    trainer.extend(extensions.snapshot(filename='training_chekpoint'))

    # Take a snapshot of the current best model
    trigger_save_model = triggers.MaxValueTrigger('validation/main/accuracy')
    trainer.extend(extensions.snapshot_object(model, filename='best_model'),
                   trigger=trigger_save_model)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Monitor learning rate at every iteration
    trainer.extend(extensions.observe_lr(), trigger=(1, 'iteration'))

    # Save plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))

        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

        trainer.extend(
            extensions.PlotReport(['lr'], 'epoch', file_name='lr.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy', 'lr', 'elapsed_time'
        ]))

    if args.bar:
        # Print a progress bar to stdout
        trainer.extend(extensions.ProgressBar())

    # Write training log to TensorBoard log file
    trainer.extend(
        TensorboardLogger(writer, [
            'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy', 'lr'
        ]))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    print("Finished preparation. Starting model training...")
    print()

    # Run the training
    trainer.run()
示例#12
0
文件: train.py 项目: diceroll/kmnist
def main():
    """Train one cross-validation fold of the KMNIST classifier.

    Parses the command-line options, seeds the random number generators,
    builds ``ModifiedClassifier(SEResNeXt101())`` with Momentum SGD, and
    trains it with a Chainer ``Trainer``.  The output directory
    ``output/<name>`` is wiped and recreated before training, a
    ``train_params.txt`` summary is written into it, and model/optimizer
    snapshots are stored there.
    """
    parser = argparse.ArgumentParser(description='training mnist')
    parser.add_argument('--gpu',
                        '-g',
                        default=-1,
                        type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--seed',
                        '-s',
                        type=int,
                        default=0,
                        help='Random seed')
    parser.add_argument('--n_fold',
                        '-nf',
                        type=int,
                        default=5,
                        help='n_fold cross validation')
    parser.add_argument('--fold', '-f', type=int, default=1)
    parser.add_argument('--out_dir_name',
                        '-dn',
                        type=str,
                        default=None,
                        help='Name of the output directory')
    parser.add_argument('--report_trigger',
                        '-rt',
                        type=str,
                        default='1e',
                        help='Interval for reporting(Ex.100i, default:1e)')
    parser.add_argument('--save_trigger',
                        '-st',
                        type=str,
                        default='1e',
                        help='Interval for saving the model'
                        '(Ex.100i, default:1e)')
    parser.add_argument('--load_model',
                        '-lm',
                        type=str,
                        default=None,
                        help='Path of the model object to load')
    parser.add_argument('--load_optimizer',
                        '-lo',
                        type=str,
                        default=None,
                        help='Path of the optimizer object to load')
    args = parser.parse_args()

    def parse_interval(spec):
        """Turn a spec such as ``'100i'`` or ``'1e'`` into a trigger tuple.

        Everything but the last character is the integer period; a trailing
        ``'i'`` selects iterations and any other suffix selects epochs.
        """
        unit = 'iteration' if spec[-1] == 'i' else 'epoch'
        return int(spec[:-1]), unit

    if args.out_dir_name is None:
        start_time = datetime.now()
        out_dir = Path('output/{}'.format(start_time.strftime('%Y%m%d_%H%M')))
    else:
        out_dir = Path('output/{}'.format(args.out_dir_name))

    random.seed(args.seed)
    np.random.seed(args.seed)
    cupy.random.seed(args.seed)
    chainer.config.cudnn_deterministic = True

    # Alternative predictors that have been tried:
    # model = ModifiedClassifier(SEResNeXt50())
    # model = ModifiedClassifier(SERes2Net50())
    model = ModifiedClassifier(SEResNeXt101())

    if args.load_model is not None:
        serializers.load_npz(args.load_model, model)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    optimizer = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(1e-4))
    if args.load_optimizer is not None:
        serializers.load_npz(args.load_optimizer, optimizer)

    # Fold i is the strided slice that starts at index i and takes every
    # n_fold-th element; --fold is 1-based on the command line.
    n_fold = args.n_fold
    slices = [slice(i, None, n_fold) for i in range(n_fold)]
    fold = args.fold - 1

    # model1
    # augmentation = [
    #     ('Rotate', {'p': 0.8, 'limit': 5}),
    #     ('PadIfNeeded', {'p': 0.5, 'min_height': 28, 'min_width': 30}),
    #     ('PadIfNeeded', {'p': 0.5, 'min_height': 30, 'min_width': 28}),
    #     ('Resize', {'p': 1.0, 'height': 28, 'width': 28}),
    #     ('RandomScale', {'p': 1.0, 'scale_limit': 0.1}),
    #     ('PadIfNeeded', {'p': 1.0, 'min_height': 32, 'min_width': 32}),
    #     ('RandomCrop', {'p': 1.0, 'height': 28, 'width': 28}),
    #     ('Mixup', {'p': 0.5}),
    #     ('Cutout', {'p': 0.5, 'num_holes': 4, 'max_h_size': 4,
    #                 'max_w_size': 4}),
    # ]
    # resize = None

    # model2
    augmentation = [
        ('Rotate', {
            'p': 0.8,
            'limit': 5
        }),
        ('PadIfNeeded', {
            'p': 0.5,
            'min_height': 28,
            'min_width': 32
        }),
        ('PadIfNeeded', {
            'p': 0.5,
            'min_height': 32,
            'min_width': 28
        }),
        ('Resize', {
            'p': 1.0,
            'height': 32,
            'width': 32
        }),
        ('RandomScale', {
            'p': 1.0,
            'scale_limit': 0.1
        }),
        ('PadIfNeeded', {
            'p': 1.0,
            'min_height': 36,
            'min_width': 36
        }),
        ('RandomCrop', {
            'p': 1.0,
            'height': 32,
            'width': 32
        }),
        ('Mixup', {
            'p': 0.5
        }),
        ('Cutout', {
            'p': 0.5,
            'num_holes': 4,
            'max_h_size': 4,
            'max_w_size': 4
        }),
    ]
    resize = [('Resize', {'p': 1.0, 'height': 32, 'width': 32})]

    # NOTE(review): presumably drop_index excludes the validation slice from
    # the training data and index selects it -- confirm against KMNIST.
    train_data = KMNIST(augmentation=augmentation,
                        drop_index=slices[fold],
                        pseudo_labeling=True)
    valid_data = KMNIST(augmentation=resize, index=slices[fold])

    train_iter = iterators.SerialIterator(train_data, args.batchsize)
    valid_iter = iterators.SerialIterator(valid_data,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    updater = StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = Trainer(updater, (args.epoch, 'epoch'), out=out_dir)

    report_trigger = parse_interval(args.report_trigger)
    trainer.extend(extensions.LogReport(trigger=report_trigger))
    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu),
                   name='val',
                   trigger=report_trigger)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'main/accuracy', 'val/main/loss',
        'val/main/accuracy', 'elapsed_time'
    ]),
                   trigger=report_trigger)
    trainer.extend(
        extensions.PlotReport(['main/loss', 'val/main/loss'],
                              x_key=report_trigger[1],
                              marker='.',
                              file_name='loss.png',
                              trigger=report_trigger))
    trainer.extend(
        extensions.PlotReport(['main/accuracy', 'val/main/accuracy'],
                              x_key=report_trigger[1],
                              marker='.',
                              file_name='accuracy.png',
                              trigger=report_trigger))

    # Periodic snapshots; the file name embeds the updater's iteration or
    # epoch counter, whichever unit the save trigger uses.
    save_trigger = parse_interval(args.save_trigger)
    trainer.extend(extensions.snapshot_object(
        model,
        filename='model_{0}-{{.updater.{0}}}.npz'.format(save_trigger[1])),
                   trigger=save_trigger)
    trainer.extend(extensions.snapshot_object(
        optimizer,
        filename='optimizer_{0}-{{.updater.{0}}}.npz'.format(save_trigger[1])),
                   trigger=save_trigger)
    trainer.extend(extensions.ProgressBar())
    trainer.extend(CosineAnnealing(lr_max=0.1, lr_min=1e-6, T_0=20),
                   trigger=(1, 'epoch'))

    # Keep the model/optimizer pair with the best validation accuracy ...
    best_model_trigger = triggers.MaxValueTrigger('val/main/accuracy',
                                                  trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model,
                                              filename='best_model.npz'),
                   trigger=best_model_trigger)
    trainer.extend(extensions.snapshot_object(optimizer,
                                              filename='best_optimizer.npz'),
                   trigger=best_model_trigger)
    # ... and the pair with the lowest validation loss.
    best_loss_model_trigger = triggers.MinValueTrigger('val/main/loss',
                                                       trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model,
                                              filename='best_loss_model.npz'),
                   trigger=best_loss_model_trigger)
    trainer.extend(extensions.snapshot_object(
        optimizer, filename='best_loss_optimizer.npz'),
                   trigger=best_loss_model_trigger)

    # Start from an empty output directory.  parents=True also creates the
    # enclosing 'output' directory, which a plain mkdir() would fail on when
    # it does not exist yet.
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True)

    # Write parameters text
    with open(out_dir / 'train_params.txt', 'w') as f:
        f.write('model: {}\n'.format(model.predictor.__class__.__name__))
        f.write('n_epoch: {}\n'.format(args.epoch))
        f.write('batch_size: {}\n'.format(args.batchsize))
        f.write('n_data_train: {}\n'.format(len(train_data)))
        f.write('n_data_val: {}\n'.format(len(valid_data)))
        f.write('seed: {}\n'.format(args.seed))
        f.write('n_fold: {}\n'.format(args.n_fold))
        f.write('fold: {}\n'.format(args.fold))
        f.write('augmentation: \n')
        for process, param in augmentation:
            f.write('  {}: {}\n'.format(process, param))

    trainer.run()