Example #1
def main():
    train_x, train_y, val_x, val_y = load_pascal_voc_dataset(DATASET_ROOT)
    train_dataset = YoloDataset(train_x,
                                train_y,
                                target_size=model_class.img_size,
                                n_grid=model_class.n_grid,
                                augment=True)
    test_dataset = YoloDataset(val_x,
                               val_y,
                               target_size=model_class.img_size,
                               n_grid=model_class.n_grid,
                               augment=False)

    # Down-weight class 0 (the background / no-object class) relative to the rest.
    class_weights = [1.0] * train_dataset.n_classes
    class_weights[0] = 0.2
    model = model_class(n_classes=train_dataset.n_classes,
                        n_base_units=6,
                        class_weights=class_weights)
    if os.path.exists(RESULT_DIR + '/model_last.npz'):
        print('continue from previous result')
        chainer.serializers.load_npz(RESULT_DIR + '/model_last.npz', model)
    optimizer = Adam()
    optimizer.setup(model)

    train_iter = SerialIterator(train_dataset, batch_size=BATCH_SIZE)
    test_iter = SerialIterator(test_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False,
                               repeat=False)
    updater = StandardUpdater(train_iter, optimizer, device=DEVICE)
    trainer = Trainer(updater, (N_EPOCHS, 'epoch'), out=RESULT_DIR)

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.Evaluator(test_iter, model, device=DEVICE))
    trainer.extend(
        extensions.PrintReport([
            'main/loss',
            'validation/main/loss',
            'main/cl_loss',
            'validation/main/cl_loss',
            'main/cl_acc',
            'validation/main/cl_acc',
            'main/pos_loss',
            'validation/main/pos_loss',
        ]))
    trainer.extend(extensions.snapshot_object(model, 'best_loss.npz'),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(extensions.snapshot_object(model,
                                              'best_classification.npz'),
                   trigger=triggers.MaxValueTrigger('validation/main/cl_acc'))
    trainer.extend(
        extensions.snapshot_object(model, 'best_position.npz'),
        trigger=triggers.MinValueTrigger('validation/main/pos_loss'))
    trainer.extend(extensions.snapshot_object(model, 'model_last.npz'),
                   trigger=(1, 'epoch'))

    trainer.run()
Example #2
def setup_record_trigger(training_type):

    if training_type == 'regression' or training_type == 'multi_regression':
        return triggers.MinValueTrigger('validation/main/loss')

    else:
        raise ValueError('Invalid training type: {}'.format(training_type))
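The helper above only covers regression objectives. Below is a minimal sketch of how a classification branch could be added; the 'classification' case and its 'validation/main/accuracy' report key are assumptions, not part of the original code.

# Hypothetical extension of setup_record_trigger (classification branch assumed).
from chainer.training import triggers

def setup_record_trigger(training_type):
    if training_type in ('regression', 'multi_regression'):
        # Snapshot whenever the validation loss reaches a new minimum.
        return triggers.MinValueTrigger('validation/main/loss')
    if training_type == 'classification':
        # Snapshot whenever the validation accuracy reaches a new maximum.
        return triggers.MaxValueTrigger('validation/main/accuracy')
    raise ValueError('Invalid training type: {}'.format(training_type))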
Example #3
def main():
    #%% Load datasets
    train, valid, test, train_moles, valid_moles, test_moles = load_dataset(
        CTYPE)

    train_gp = train.groupby('molecule_name')
    valid_gp = valid.groupby('molecule_name')
    test_gp = test.groupby('molecule_name')

    #%%
    structures = pd.read_csv(DATA_PATH / 'structures.csv')

    giba_features = pd.read_csv(DATA_PATH / 'unified-features' /
                                'giba_features.csv',
                                index_col=0)
    structures = pd.merge(structures,
                          giba_features.drop(['atom_name', 'x', 'y', 'z'],
                                             axis=1),
                          on=['molecule_name', 'atom_index'])
    norm_col = [
        col for col in structures.columns
        if col not in ['molecule_name', 'atom_index', 'atom', 'x', 'y', 'z']
    ]
    structures[norm_col] = (structures[norm_col] - structures[norm_col].mean()
                            ) / structures[norm_col].std()
    structures = structures.fillna(0)
    structures_groups = structures.groupby('molecule_name')

    #%%
    if CTYPE != 'all':
        train_couple = pd.read_csv(DATA_PATH / 'typewise-dataset' /
                                   'kuma_dataset' / 'kuma_dataset' / 'train' /
                                   '{}_full.csv'.format(CTYPE),
                                   index_col=0)
    else:
        train_couple = pd.read_csv(DATA_PATH / 'typewise-dataset' /
                                   'kuma_dataset' / 'kuma_dataset' /
                                   'train_all.csv',
                                   index_col=0)
    train_couple = reduce_mem_usage(train_couple)
    train_couple = train_couple.drop(
        ['id', 'scalar_coupling_constant', 'type'], axis=1)
    if CTYPE != 'all':
        test_couple = pd.read_csv(DATA_PATH / 'typewise-dataset' /
                                  'kuma_dataset' / 'kuma_dataset' / 'test' /
                                  '{}_full.csv'.format(CTYPE),
                                  index_col=0)
    else:
        test_couple = pd.read_csv(DATA_PATH / 'typewise-dataset' /
                                  'kuma_dataset' / 'kuma_dataset' /
                                  'test_all.csv',
                                  index_col=0)
    test_couple = reduce_mem_usage(test_couple)
    test_couple = test_couple.drop(['id', 'type'], axis=1)

    couples = pd.concat([train_couple, test_couple])

    del train_couple, test_couple

    couples_norm_col = [
        col for col in couples.columns if col not in
        ['atom_index_0', 'atom_index_1', 'molecule_name', 'type']
    ]

    for col in couples_norm_col:
        if couples[col].dtype == np.dtype('O'):
            couples = pd.get_dummies(couples, columns=[col])
        else:
            couples[col] = (couples[col] -
                            couples[col].mean()) / couples[col].std()

    couples = couples.fillna(0)
    couples = couples.replace(np.inf, 0)
    couples = couples.replace(-np.inf, 0)
    couples_groups = couples.groupby('molecule_name')

    #%% Make graphs
    feature_col = [
        col for col in structures.columns
        if col not in ['molecule_name', 'atom_index', 'atom']
    ]

    list_atoms = list(set(structures['atom']))
    print('list of atoms')
    print(list_atoms)

    train_graphs = list()
    train_targets = list()
    train_couples = list()
    print('preprocess training molecules ...')
    for mole in tqdm(train_moles):
        train_graphs.append(
            Graph(structures_groups.get_group(mole), list_atoms, feature_col))
        train_targets.append(train_gp.get_group(mole))
        train_couples.append(couples_groups.get_group(mole))

    valid_graphs = list()
    valid_targets = list()
    valid_couples = list()
    print('preprocess validation molecules ...')
    for mole in tqdm(valid_moles):
        valid_graphs.append(
            Graph(structures_groups.get_group(mole), list_atoms, feature_col))
        valid_targets.append(valid_gp.get_group(mole))
        valid_couples.append(couples_groups.get_group(mole))

    test_graphs = list()
    test_targets = list()
    test_couples = list()
    print('preprocess test molecules ...')
    for mole in tqdm(test_moles):
        test_graphs.append(
            Graph(structures_groups.get_group(mole), list_atoms, feature_col))
        test_targets.append(test_gp.get_group(mole))
        test_couples.append(couples_groups.get_group(mole))

    #%% Make datasets
    train_dataset = DictDataset(graphs=train_graphs,
                                targets=train_targets,
                                couples=train_couples)
    valid_dataset = DictDataset(graphs=valid_graphs,
                                targets=valid_targets,
                                couples=valid_couples)
    test_dataset = DictDataset(graphs=test_graphs,
                               targets=test_targets,
                               couples=test_couples)

    #%% Build Model
    model = SchNet(num_layer=NUM_LAYER)
    model.to_gpu(device=0)

    #%% Sampler
    train_sampler = SameSizeSampler(structures_groups, train_moles, BATCH_SIZE)
    valid_sampler = SameSizeSampler(structures_groups,
                                    valid_moles,
                                    BATCH_SIZE,
                                    use_remainder=True)
    test_sampler = SameSizeSampler(structures_groups,
                                   test_moles,
                                   BATCH_SIZE,
                                   use_remainder=True)

    #%% Iterator, Optimizer
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  BATCH_SIZE,
                                                  order_sampler=train_sampler)

    valid_iter = chainer.iterators.SerialIterator(valid_dataset,
                                                  BATCH_SIZE,
                                                  repeat=False,
                                                  order_sampler=valid_sampler)

    test_iter = chainer.iterators.SerialIterator(test_dataset,
                                                 BATCH_SIZE,
                                                 repeat=False,
                                                 order_sampler=test_sampler)

    optimizer = optimizers.Adam(alpha=1e-3)
    optimizer.setup(model)

    #%% Updater
    if opt.multi_gpu:
        updater = training.updaters.ParallelUpdater(
            train_iter,
            optimizer,
            # The device of the name 'main' is used as a "master", while others are
            # used as slaves. Names other than 'main' are arbitrary.
            devices={
                'main': 0,
                'sub1': 1,
                'sub2': 2,
                'sub3': 3
            },
        )
    else:
        updater = training.StandardUpdater(train_iter,
                                           optimizer,
                                           converter=coupling_converter,
                                           device=0)

    # early_stopping
    stop_trigger = triggers.EarlyStoppingTrigger(
        patients=EARLY_STOPPING_ROUNDS,
        monitor='valid/main/ALL_LogMAE',
        max_trigger=(EPOCH, 'epoch'))
    trainer = training.Trainer(updater, stop_trigger, out=RESULT_PATH)
    # trainer = training.Trainer(updater, (100, 'epoch'), out=RESULT_PATH)

    #%% Evaluator
    trainer.extend(
        TypeWiseEvaluator(iterator=valid_iter,
                          target=model,
                          converter=coupling_converter,
                          name='valid',
                          device=0,
                          is_validate=True))
    trainer.extend(
        TypeWiseEvaluator(iterator=test_iter,
                          target=model,
                          converter=coupling_converter,
                          name='test',
                          device=0,
                          is_submit=True))

    #%% Other extensions
    trainer.extend(training.extensions.ExponentialShift('alpha', 0.99999))

    trainer.extend(stop_train_mode(trigger=(1, 'epoch')))

    trainer.extend(
        training.extensions.observe_value(
            'alpha', lambda tr: tr.updater.get_optimizer('main').alpha))

    trainer.extend(training.extensions.LogReport(log_name=f'log_{CTYPE}'))
    trainer.extend(
        training.extensions.PrintReport([
            'epoch', 'elapsed_time', 'main/loss', 'valid/main/ALL_LogMAE',
            'alpha'
        ]))

    # trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
    trainer.extend(SaveRestore(filename=f'best_epoch_{CTYPE}'),
                   trigger=triggers.MinValueTrigger('valid/main/ALL_LogMAE'))

    #%% Train
    if not opt.test:
        chainer.config.train = True
        trainer.run()
    else:
        chainer.config.train = False
        snapshot_path = f'results/chainer/best_epoch_{CTYPE}'
        chainer.serializers.npz.load_npz(snapshot_path, model,
                                         'updater/model:main/')
        oof = predict_iter(valid_iter, model)
        oof.to_csv(f'schnet_{CTYPE}_oof.csv', index=False)

    #%% Final Evaluation
    chainer.config.train = False
    prediction = predict_iter(test_iter, model)
    prediction.to_csv(f'schnet_{CTYPE}.csv', index=False)
Example #4
def main():
    args = parse_args()
    with open(args.config_path) as f:
        config = json.load(f)
    with open(args.app_config) as f:
        app_config = json.load(f)
    app_train_config = app_config.get('train', {})
    command_config = {'gpu': args.gpu}
    if args.output_dir is not None:
        command_config['output_dir'] = args.output_dir
    config.update(command_config)

    device_id = config['gpu']
    batch_size = config['batch_size']

    network_params = config['network']
    nets = {k: util.create_network(v) for k, v in network_params.items()}
    optimizers = {
        k: util.create_optimizer(v['optimizer'], nets[k])
        for k, v in network_params.items()
    }
    if len(optimizers) == 1:
        key, target_optimizer = list(optimizers.items())[0]
        target = nets[key]
    else:
        target = nets
        target_optimizer = optimizers

    if device_id >= 0:
        chainer.cuda.get_device_from_id(device_id).use()
        for net in nets.values():
            net.to_gpu()

    datasets = dataset.get_dataset()
    iterators = {}
    if isinstance(datasets, dict):
        for name, data in datasets.items():
            if name == 'train':
                train_iterator = chainer.iterators.SerialIterator(
                    data, batch_size)
            else:
                iterators[name] = chainer.iterators.SerialIterator(
                    data, batch_size, repeat=False, shuffle=False)
    else:
        train_iterator = chainer.iterators.SerialIterator(datasets, batch_size)
    updater = TrainingStep(train_iterator,
                           target_optimizer,
                           model.calculate_metrics,
                           device=device_id)
    trainer = Trainer(updater, (config['epoch'], 'epoch'),
                      out=config['output_dir'])
    if hasattr(model, 'make_eval_func'):
        for name, iterator in iterators.items():
            evaluator = extensions.Evaluator(
                iterator,
                target,
                eval_func=model.make_eval_func(target),
                device=device_id)
            trainer.extend(evaluator, name=name)

    dump_graph_node = app_train_config.get('dump_graph', None)
    if dump_graph_node is not None:
        trainer.extend(extensions.dump_graph(dump_graph_node))

    trainer.extend(extensions.snapshot(filename='snapshot.state'),
                   trigger=(1, 'epoch'))
    for k, net in nets.items():
        file_name = 'latest.{}.model'.format(k)
        trainer.extend(extensions.snapshot_object(net, filename=file_name),
                       trigger=(1, 'epoch'))
    max_value_trigger_key = app_train_config.get('max_value_trigger', None)
    min_value_trigger_key = app_train_config.get('min_value_trigger', None)
    if max_value_trigger_key is not None:
        trigger = triggers.MaxValueTrigger(max_value_trigger_key)
        for key, net in nets.items():
            file_name = 'best.{}.model'.format(key)
            trainer.extend(extensions.snapshot_object(net, filename=file_name),
                           trigger=trigger)
    elif min_value_trigger_key is not None:
        trigger = triggers.MinValueTrigger(min_value_trigger_key)
        for key, net in nets.items():
            file_name = 'best.{}.model'.format(key)
            trainer.extend(extensions.snapshot_object(net, file_name),
                           trigger=trigger)
    trainer.extend(extensions.LogReport())
    if len(optimizers) == 1:
        for name, opt in optimizers.items():
            if not hasattr(opt, 'lr'):
                continue
            trainer.extend(extensions.observe_lr(name))
    else:
        for name, opt in optimizers.items():
            if not hasattr(opt, 'lr'):
                continue
            key = '{}/lr'.format(name)
            trainer.extend(extensions.observe_lr(name, key))

    if extensions.PlotReport.available():
        plot_targets = app_train_config.get('plot_report', {})
        for name, targets in plot_targets.items():
            file_name = '{}.png'.format(name)
            trainer.extend(
                extensions.PlotReport(targets, 'epoch', file_name=file_name))

    if not args.silent:
        print_targets = app_train_config.get('print_report', [])
        if print_targets is not None and print_targets != []:
            trainer.extend(extensions.PrintReport(print_targets))
        trainer.extend(extensions.ProgressBar())

    trainer.extend(generate_image(nets['gen'], 10, 10,
                                  config['output_image_dir']),
                   trigger=(1, 'epoch'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Example #5
def TrainUNet(X,
              Y,
              model_=None,
              optimizer_=None,
              epoch=40,
              alpha=0.001,
              gpu_id=0,
              loop=1,
              earlystop=True):
    assert (len(X) == len(Y))
    d_time = datetime.datetime.now().strftime("%m-%d-%H-%M-%S")

    # 1. Model load.

    # print(sum(p.data.size for p in model.unet.params()))
    if model_ is not None:
        model = Regressor(model_)
        print("## model loaded.")
    else:
        model = Regressor(UNet())

    model.compute_accuracy = False

    if gpu_id >= 0:
        model.to_gpu(gpu_id)

    # 2. optimizer load.

    if optimizer_ is not None:
        # A preloaded optimizer is used as-is; setup() is not called again here.
        opt = optimizer_
        print("## optimizer loaded.")
    else:
        opt = optimizers.Adam(alpha=alpha)
        opt.setup(model)

    # 3. Data Split.
    dataset = Unet_DataSet(X, Y)
    print("# number of patterns", len(dataset))

    train, valid = \
        split_dataset_random(dataset, int(len(dataset) * 0.8), seed=0)

    # 4. Iterator
    train_iter = SerialIterator(train, batch_size=C.BATCH_SIZE)
    test_iter = SerialIterator(valid,
                               batch_size=C.BATCH_SIZE,
                               repeat=False,
                               shuffle=False)

    # 5. config train, enable backprop
    chainer.config.train = True
    chainer.config.enable_backprop = True

    # 6. UnetUpdater
    updater = UnetUpdater(train_iter, opt, model, device=gpu_id)

    # 7. EarlyStopping
    if earlystop:
        stop_trigger = triggers.EarlyStoppingTrigger(
            monitor='validation/main/loss',
            max_trigger=(epoch, 'epoch'),
            patients=5)
    else:
        stop_trigger = (epoch, 'epoch')

    # 8. Trainer
    trainer = training.Trainer(updater, stop_trigger, out=C.PATH_TRAINRESULT)

    # 8.1. UnetEvaluator
    trainer.extend(UnetEvaluator(test_iter, model, device=gpu_id))

    trainer.extend(SaveRestore(),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))

    # 8.2. Extensions LogReport
    trainer.extend(extensions.LogReport())

    # 8.3. Extension Snapshot
    # trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
    # trainer.extend(extensions.snapshot_object(model.unet, filename='loop' + str(loop) + '.model'))

    # 8.4. Print Report
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'elapsed_time', 'lr'
        ]))

    # 8.5. Extension Graph
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              x_key='epoch',
                              file_name='loop-' + str(loop) + '-loss' +
                              d_time + '.png'))
    # trainer.extend(extensions.dump_graph('main/loss'))

    # 8.6. Progress Bar
    trainer.extend(extensions.ProgressBar())

    # 9. Trainer run
    trainer.run()

    chainer.serializers.save_npz(C.PATH_TRAINRESULT / ('loop' + str(loop)),
                                 model.unet)
    return model.unet, opt
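A minimal usage sketch for the TrainUNet helper above; X and Y (paired input/target arrays) and the constants in the C module are assumed to be provided by the surrounding project.

# Hypothetical driver code: one training loop, then switch to inference mode.
unet, opt = TrainUNet(X, Y, epoch=40, alpha=0.001, gpu_id=0, loop=1)
chainer.config.train = False
chainer.config.enable_backprop = False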
Example #6
def run(dir_dataset: Path, batch_size: int, epochs: int, alpha: float,
        seed: int, debug: bool):

    tic = time.time()

    logger = getLogger('root')

    np.random.seed(seed)
    random.seed(seed)

    model = EdgeUpdateNet()
    model.to_gpu(device=0)

    train_ids, valid_ids, test_ids = load_dataset(dir_dataset)

    logger.info(f'train_ids: {train_ids[:5]} ... {train_ids[-5:]}')
    logger.info(f'valid_ids: {valid_ids[:5]} ... {valid_ids[-5:]}')
    logger.info(f' test_ids: {test_ids[:5]} ... {test_ids[-5:]}')

    train_scores = pd.read_csv(dir_dataset / 'train_scores.csv')
    train_scores.index = train_scores['Id']

    target_cols = [
        'age', 'domain1_var1', 'domain1_var2', 'domain2_var1', 'domain2_var2'
    ]
    train_target = train_scores.loc[train_ids][target_cols].values.astype(
        np.float32)
    valid_target = train_scores.loc[valid_ids][target_cols].values.astype(
        np.float32)
    test_target = np.zeros((len(test_ids), len(target_cols)), dtype=np.float32)

    loading = pd.read_csv(dir_dataset / 'loading.csv')
    loading.index = loading['Id']

    loading_train = loading.loc[train_ids].iloc[:,
                                                1:].values.astype(np.float32)
    loading_valid = loading.loc[valid_ids].iloc[:,
                                                1:].values.astype(np.float32)
    loading_test = loading.loc[test_ids].iloc[:, 1:].values.astype(np.float32)

    fnc_train, fnc_valid, fnc_test = get_fnc(dir_dataset, train_ids, valid_ids,
                                             test_ids, alpha)

    logger.info(f'fnc train: {fnc_train.shape}')
    logger.info(f'fnc valid: {fnc_valid.shape}')
    logger.info(f'fnc  test: {fnc_test.shape}')

    icn_numbers = pd.read_csv('../../input/ICN_numbers.csv')
    feature = np.zeros((53, len(icn_numbers['net_type'].unique())),
                       dtype=np.float32)
    feature[range(len(feature)), icn_numbers['net_type_code']] = 1.0

    net_type_train = np.tile(np.expand_dims(feature, 0),
                             (len(train_ids), 1, 1))
    net_type_valid = np.tile(np.expand_dims(feature, 0),
                             (len(valid_ids), 1, 1))
    net_type_test = np.tile(np.expand_dims(feature, 0), (len(test_ids), 1, 1))

    spatial_map_train, spatial_map_valid = load_spatial_map(
        train_ids, valid_ids)
    spatial_map_test = np.load('../../input/spatial_map_test.npy')

    train_dataset = DictDataset(loading=loading_train,
                                fnc=fnc_train,
                                net_type=net_type_train,
                                spatial_map=spatial_map_train,
                                targets=train_target,
                                Id=train_ids)

    valid_dataset = DictDataset(loading=loading_valid,
                                fnc=fnc_valid,
                                net_type=net_type_valid,
                                spatial_map=spatial_map_valid,
                                targets=valid_target,
                                Id=valid_ids)

    test_dataset = DictDataset(loading=loading_test,
                               fnc=fnc_test,
                               net_type=net_type_test,
                               spatial_map=spatial_map_test,
                               targets=test_target,
                               Id=test_ids)

    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  batch_size,
                                                  shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(valid_dataset,
                                                  batch_size,
                                                  shuffle=False,
                                                  repeat=False)
    test_iter = chainer.iterators.SerialIterator(test_dataset,
                                                 batch_size,
                                                 shuffle=False,
                                                 repeat=False)

    optimizer = optimizers.Adam(alpha=1e-3)
    optimizer.setup(model)

    updater = training.StandardUpdater(train_iter, optimizer, device=0)
    trainer = training.Trainer(updater, (epochs, 'epoch'), out="result")

    trainer.extend(training.extensions.LogReport(filename=f'seed{seed}.log'))

    trainer.extend(training.extensions.ExponentialShift('alpha', 0.99999))
    trainer.extend(
        training.extensions.observe_value(
            'alpha', lambda tr: tr.updater.get_optimizer('main').alpha))

    def stop_train_mode(trigger):
        @make_extension(trigger=trigger)
        def _stop_train_mode(_):
            logger.debug('turn off training mode')
            chainer.config.train = False

        return _stop_train_mode

    trainer.extend(stop_train_mode(trigger=(1, 'epoch')))

    trainer.extend(
        training.extensions.PrintReport(
            ['epoch', 'elapsed_time', 'main/loss', 'valid/main/All', 'alpha']))

    trainer.extend(
        TreNDSEvaluator(iterator=valid_iter,
                        target=model,
                        name='valid',
                        device=0,
                        is_validate=True))

    trainer.extend(TreNDSEvaluator(iterator=test_iter,
                                   target=model,
                                   name='test',
                                   device=0,
                                   is_submit=True,
                                   submission_name=f'submit_seed{seed}.csv'),
                   trigger=triggers.MinValueTrigger('valid/main/All'))

    chainer.config.train = True
    trainer.run()

    trained_result = pd.DataFrame(trainer.get_extension('LogReport').log)
    best_score = np.min(trained_result['valid/main/All'])
    logger.info(f'validation score: {best_score: .4f} (seed: {seed})')

    elapsed_time = time.time() - tic
    logger.info(f'elapsed time: {elapsed_time / 60.0: .1f} [min]')
Example #7
def train(args):
    if not os.path.exists(args.out):
        os.makedirs(args.out)
    if args.gpu >= 0:
        cuda.check_cuda_available()
        cuda.get_device(args.gpu).use()
    if args.random_seed:
        set_random_seed(args.random_seed, (args.gpu,))

    user2index = load_dict(os.path.join(args.indir, USER_DICT_FILENAME))
    item2index = load_dict(os.path.join(args.indir, ITEM_DICT_FILENAME))
    (trimmed_word2count, word2index, aspect2index, opinion2index) = read_and_trim_vocab(
        args.indir, args.trimfreq
    )
    aspect_opinions = get_aspect_opinions(os.path.join(args.indir, TRAIN_FILENAME))

    export_params(
        args,
        user2index,
        item2index,
        trimmed_word2count,
        word2index,
        aspect2index,
        opinion2index,
        aspect_opinions,
    )

    src_aspect_score = SOURCE_ASPECT_SCORE.get(args.context, "aspect_score_efm")

    data_loader = DataLoader(
        args.indir,
        user2index,
        item2index,
        trimmed_word2count,
        word2index,
        aspect2index,
        opinion2index,
        aspect_opinions,
        src_aspect_score,
    )

    train_iter, val_iter = get_dataset_iterator(
        args.context, data_loader, args.batchsize
    )

    model = get_context_model(args, data_loader)

    if args.optimizer == "rmsprop":
        optimizer = O.RMSprop(lr=args.learning_rate, alpha=args.alpha)
    elif args.optimizer == "adam":
        optimizer = O.Adam(amsgrad=args.amsgrad)

    optimizer.setup(model)
    if args.grad_clip:
        optimizer.add_hook(GradientClipping(args.grad_clip))
    if args.gpu >= 0:
        model.to_gpu(args.gpu)

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, converter=convert, device=args.gpu
    )
    early_stop = triggers.EarlyStoppingTrigger(
        monitor="validation/main/loss",
        patients=args.patients,
        max_trigger=(args.epoch, "epoch"),
    )
    trainer = training.Trainer(updater, stop_trigger=early_stop, out=args.out)
    trainer.extend(
        extensions.Evaluator(val_iter, model, converter=convert, device=args.gpu)
    )
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport(
            ["epoch", "main/loss", "validation/main/loss", "lr", "elapsed_time"]
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/loss", "validation/main/loss"], x_key="epoch", file_name="loss.png"
        )
    )
    trainer.extend(extensions.ProgressBar())
    trainer.extend(
        extensions.snapshot_object(model, MODEL_FILENAME),
        trigger=triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(extensions.observe_lr())

    if args.optimizer in ["rmsprop"]:
        if args.schedule_lr:
            epoch_list = np.array(
                [i for i in range(1, int(args.epoch / args.stepsize) + 1)]
            ).astype(np.int32)
            value_list = args.learning_rate * args.lr_reduce ** epoch_list
            value_list[value_list < args.min_learning_rate] = args.min_learning_rate
            epoch_list *= args.stepsize
            epoch_list += args.begin_step
            trainer.extend(
                schedule_optimizer_value(epoch_list.tolist(), value_list.tolist())
            )

    trainer.run()
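To make the learning-rate schedule built above concrete, here is the same computation with hypothetical values (learning_rate=0.01, lr_reduce=0.5, stepsize=10, epoch=40, begin_step=0, min_learning_rate=1e-5); the real values come from the command-line arguments.

import numpy as np

learning_rate, lr_reduce, stepsize, epoch = 0.01, 0.5, 10, 40
epoch_list = np.arange(1, epoch // stepsize + 1, dtype=np.int32)  # [1 2 3 4]
value_list = learning_rate * lr_reduce ** epoch_list              # [0.005 0.0025 0.00125 0.000625]
value_list[value_list < 1e-5] = 1e-5   # floor at min_learning_rate (no effect with these numbers)
epoch_list *= stepsize                 # learning rate changes at epochs 10, 20, 30, 40
# schedule_optimizer_value(epoch_list.tolist(), value_list.tolist()) then applies
# each value once the corresponding epoch is reached.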
Example #8
def train(model_class,
          n_base_units,
          trained_model,
          no_obj_weight,
          data,
          result_dir,
          initial_batch_size=10,
          max_batch_size=1000,
          max_epoch=100):
    train_x, train_y, val_x, val_y = data

    max_class_id = 0
    for objs in val_y:
        for obj in objs:
            max_class_id = max(max_class_id, obj[4])
    n_classes = max_class_id + 1

    class_weights = [1.0 for i in range(n_classes)]
    class_weights[0] = no_obj_weight
    train_dataset = YoloDataset(train_x,
                                train_y,
                                target_size=model_class.img_size,
                                n_grid=model_class.n_grid,
                                augment=True,
                                class_weights=class_weights)
    test_dataset = YoloDataset(val_x,
                               val_y,
                               target_size=model_class.img_size,
                               n_grid=model_class.n_grid,
                               augment=False,
                               class_weights=class_weights)

    model = model_class(n_classes, n_base_units)
    model.loss_calc = LossCalculator(n_classes, class_weights=class_weights)

    last_result_file = os.path.join(result_dir, 'best_loss.npz')
    if os.path.exists(last_result_file):
        try:
            chainer.serializers.load_npz(last_result_file, model)
            print('this training has already finished; reusing the result')
            return model
        except Exception:
            pass

    if trained_model:
        print('copy params from trained model')
        copy_params(trained_model, model)

    optimizer = Adam()
    optimizer.setup(model)

    n_physical_cpu = int(math.ceil(multiprocessing.cpu_count() / 2))

    train_iter = MultiprocessIterator(train_dataset,
                                      batch_size=initial_batch_size,
                                      n_prefetch=n_physical_cpu,
                                      n_processes=n_physical_cpu)
    test_iter = MultiprocessIterator(test_dataset,
                                     batch_size=initial_batch_size,
                                     shuffle=False,
                                     repeat=False,
                                     n_prefetch=n_physical_cpu,
                                     n_processes=n_physical_cpu)
    updater = StandardUpdater(train_iter, optimizer, device=0)
    stopper = triggers.EarlyStoppingTrigger(check_trigger=(1, 'epoch'),
                                            monitor="validation/main/loss",
                                            patients=10,
                                            mode="min",
                                            max_trigger=(max_epoch, "epoch"))
    trainer = Trainer(updater, stopper, out=result_dir)

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.Evaluator(test_iter, model, device=0))
    trainer.extend(
        extensions.PrintReport([
            'epoch',
            'main/loss',
            'validation/main/loss',
            'main/cl_loss',
            'validation/main/cl_loss',
            'main/cl_acc',
            'validation/main/cl_acc',
            'main/pos_loss',
            'validation/main/pos_loss',
        ]))
    trainer.extend(extensions.snapshot_object(model, 'best_loss.npz'),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(extensions.snapshot_object(model,
                                              'best_classification.npz'),
                   trigger=triggers.MaxValueTrigger('validation/main/cl_acc'))
    trainer.extend(
        extensions.snapshot_object(model, 'best_position.npz'),
        trigger=triggers.MinValueTrigger('validation/main/pos_loss'))
    trainer.extend(extensions.snapshot_object(model, 'model_last.npz'),
                   trigger=(1, 'epoch'))
    trainer.extend(AdaptiveBatchsizeIncrement(maxsize=max_batch_size),
                   trigger=(1, 'epoch'))

    trainer.run()

    chainer.serializers.load_npz(os.path.join(result_dir, 'best_loss.npz'),
                                 model)
    return model
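AdaptiveBatchsizeIncrement is referenced above but not defined in this excerpt. The sketch below is one possible shape for it, assuming it simply grows the training batch size each time it fires (it is registered above with trigger=(1, 'epoch')); whether changing batch_size on a running MultiprocessIterator takes effect cleanly depends on the Chainer version.

# Hypothetical sketch of the AdaptiveBatchsizeIncrement extension used above.
from chainer.training import Extension

class AdaptiveBatchsizeIncrement(Extension):
    def __init__(self, maxsize):
        self.maxsize = maxsize

    def __call__(self, trainer):
        # Double the main iterator's batch size, capped at maxsize.
        iterator = trainer.updater.get_iterator('main')
        iterator.batch_size = min(iterator.batch_size * 2, self.maxsize)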
Example #9
def main(args=None):
    set_random_seed(63)
    chainer.global_config.autotune = True
    chainer.cuda.set_max_workspace_size(512 * 1024 * 1024)
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--learnrate',
                        '-l',
                        type=float,
                        default=0.01,
                        help='Learning rate for SGD')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=80,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--loss-function',
                        choices=['focal', 'sigmoid'],
                        default='focal')
    parser.add_argument('--optimizer',
                        choices=['sgd', 'adam', 'adabound'],
                        default='adam')
    parser.add_argument('--size', type=int, default=224)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--lr-search', action='store_true')
    parser.add_argument('--pretrained', type=str, default='')
    parser.add_argument('--backbone',
                        choices=['resnet', 'seresnet', 'debug_model'],
                        default='resnet')
    parser.add_argument('--log-interval', type=int, default=100)
    parser.add_argument('--find-threshold', action='store_true')
    parser.add_argument('--finetune', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    args = parser.parse_args() if args is None else parser.parse_args(args)

    print(args)

    if args.mixup and args.loss_function != 'focal':
        raise ValueError('mixup currently only works with the focal loss')

    train, test, cooccurrence = get_dataset(args.data_dir, args.size,
                                            args.limit, args.mixup)
    base_model = backbone_catalog[args.backbone](args.dropout)

    if args.pretrained:
        print('loading pretrained model: {}'.format(args.pretrained))
        chainer.serializers.load_npz(args.pretrained, base_model, strict=False)
    model = TrainChain(base_model,
                       1,
                       loss_fn=args.loss_function,
                       cooccurrence=cooccurrence,
                       co_coef=0)
    if args.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    if args.optimizer in ['adam', 'adabound']:
        optimizer = Adam(alpha=args.learnrate,
                         adabound=args.optimizer == 'adabound',
                         weight_decay_rate=1e-5,
                         gamma=5e-7)
    elif args.optimizer == 'sgd':
        optimizer = chainer.optimizers.MomentumSGD(lr=args.learnrate)

    optimizer.setup(model)

    if not args.finetune:
        print('freezing the feature extractor for the first epoch')
        model.freeze_extractor()

    train_iter = chainer.iterators.MultiprocessIterator(train,
                                                        args.batchsize,
                                                        n_processes=8,
                                                        n_prefetch=2)
    test_iter = chainer.iterators.MultithreadIterator(test,
                                                      args.batchsize,
                                                      n_threads=8,
                                                      repeat=False,
                                                      shuffle=False)

    if args.find_threshold:
        # train_iter, the optimizer, etc. are set up needlessly in this mode, but so be it.
        print('searching for the best threshold, then exiting')
        chainer.serializers.load_npz(join(args.out, 'bestmodel_loss'),
                                     base_model)
        print('results for the model with the lowest loss:')
        find_threshold(base_model, test_iter, args.gpu, args.out)

        chainer.serializers.load_npz(join(args.out, 'bestmodel_f2'),
                                     base_model)
        print('results for the model with the highest F2 score:')
        find_threshold(base_model, test_iter, args.gpu, args.out)
        return

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter,
        optimizer,
        device=args.gpu,
        converter=lambda batch, device: chainer.dataset.concat_examples(
            batch, device=device))
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(FScoreEvaluator(test_iter, model, device=args.gpu))

    if args.optimizer == 'sgd':
        # Weight decay reportedly does not work well with Adam
        optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(5e-4))
        trainer.extend(extensions.ExponentialShift('lr', 0.1),
                       trigger=(3, 'epoch'))
        if args.lr_search:
            print('searching for the optimal learning rate')
            trainer.extend(LRFinder(1e-7, 1, 5, optimizer),
                           trigger=(1, 'iteration'))
    elif args.optimizer in ['adam', 'adabound']:
        if args.lr_search:
            print('searching for the optimal learning rate')
            trainer.extend(LRFinder(1e-7, 1, 5, optimizer, lr_key='alpha'),
                           trigger=(1, 'iteration'))

        trainer.extend(extensions.ExponentialShift('alpha', 0.2),
                       trigger=triggers.EarlyStoppingTrigger(
                           monitor='validation/main/loss'))

    # Take a snapshot of Trainer at each epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'),
        trigger=(10, 'epoch'))

    # Take a snapshot of Model which has best val loss.
    # Because searching best threshold for each evaluation takes too much time.
    trainer.extend(extensions.snapshot_object(model.model, 'bestmodel_loss'),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(extensions.snapshot_object(model.model, 'bestmodel_f2'),
                   trigger=triggers.MaxValueTrigger('validation/main/f2'))
    trainer.extend(extensions.snapshot_object(model.model,
                                              'model_{.updater.epoch}'),
                   trigger=(5, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.log_interval, 'iteration')))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'lr', 'elapsed_time', 'main/loss', 'main/co_loss',
            'validation/main/loss', 'validation/main/co_loss',
            'validation/main/precision', 'validation/main/recall',
            'validation/main/f2', 'validation/main/threshold'
        ]))

    trainer.extend(extensions.ProgressBar(update_interval=args.log_interval))
    trainer.extend(extensions.observe_lr(),
                   trigger=(args.log_interval, 'iteration'))
    trainer.extend(CommandsExtension())
    save_args(args, args.out)

    trainer.extend(lambda trainer: model.unfreeze_extractor(),
                   trigger=(1, 'epoch'))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # save args with pickle for prediction time
    pickle.dump(args, open(str(Path(args.out).joinpath('args.pkl')), 'wb'))

    # Run the training
    trainer.run()

    # find optimal threshold
    chainer.serializers.load_npz(join(args.out, 'bestmodel_loss'), base_model)
    print('results for the model with the lowest loss:')
    find_threshold(base_model, test_iter, args.gpu, args.out)

    chainer.serializers.load_npz(join(args.out, 'bestmodel_f2'), base_model)
    print('results for the model with the highest F2 score:')
    find_threshold(base_model, test_iter, args.gpu, args.out)
Example #10
def main():
    start_time = time.time()
    ap = ArgumentParser(description='python train_cc.py')
    ap.add_argument('--indir',
                    '-i',
                    nargs='?',
                    default='datasets/train',
                    help='Specify input files directory for learning data')
    ap.add_argument(
        '--outdir',
        '-o',
        nargs='?',
        default='results/results_training_cc',
        help='Specify output files directory for create save model files')
    ap.add_argument('--train_list',
                    nargs='?',
                    default='datasets/split_list/train.list',
                    help='Specify split train list')
    ap.add_argument('--validation_list',
                    nargs='?',
                    default='datasets/split_list/validation.list',
                    help='Specify split validation list')
    ap.add_argument(
        '--init_model',
        help='Specify Loading File Path of Learned Cell Classification Model')
    ap.add_argument('--gpu',
                    '-g',
                    type=int,
                    default=-1,
                    help='Specify GPU ID (negative value indicates CPU)')
    ap.add_argument('--epoch',
                    '-e',
                    type=int,
                    default=10,
                    help='Specify number of sweeps over the dataset to train')
    ap.add_argument('--batchsize',
                    '-b',
                    type=int,
                    default=5,
                    help='Specify Batchsize')
    ap.add_argument('--crop_size',
                    nargs='?',
                    default='(640, 640)',
                    help='Specify crop size (default (y,x) = (640,640))')
    ap.add_argument(
        '--coordinate',
        nargs='?',
        default='(1840, 740)',
        help='Specify initial coordinate (default (y,x) = (1840,740))')
    ap.add_argument('--optimizer',
                    default='SGD',
                    help='Optimizer [SGD, MomentumSGD, Adam]')
    args = ap.parse_args()
    argvs = sys.argv
    psep = '/'

    print('init dataset...')
    train_dataset = PreprocessedRegressionDataset(path=args.indir,
                                                  split_list=args.train_list,
                                                  crop_size=args.crop_size,
                                                  coordinate=args.coordinate,
                                                  train=True)
    validation_dataset = PreprocessedRegressionDataset(
        path=args.indir,
        split_list=args.validation_list,
        crop_size=args.crop_size,
        coordinate=args.coordinate,
        train=False)

    #    import skimage.io as io
    #    for i in range(train_dataset.__len__()):
    #        img, label = train_dataset.get_example(i)
    #        print(label)
    #        io.imsave('dataset_roi/roi_r_{}_{0:03d}.tif'.format(label[0], i+1), np.array(img[0:3] * 255).astype(np.uint8).transpose(0, 1, 2))
    #        io.imsave('dataset_roi/roi_l_{}_{0:03d}.tif'.format(label[0], i+1), np.array(img[3:6] * 255).astype(np.uint8).transpose(0, 1, 2))

    print('init model construction')
    model = Regressor(CCNet(n_class=1), lossfun=F.mean_squared_error)

    if args.init_model is not None:
        print('Load model from', args.init_model)
        chainer.serializers.load_npz(args.init_model, model)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    print('init optimizer...')
    if args.optimizer == 'SGD':
        optimizer = chainer.optimizers.SGD(lr=0.00001)
    elif args.optimizer == 'MomentumSGD':
        optimizer = chainer.optimizers.MomentumSGD(lr=0.00001)
    elif args.optimizer == 'Adam':
        optimizer = chainer.optimizers.Adam()
    else:
        print('Specify optimizer name')
        sys.exit()
    optimizer.setup(model)
    # optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(rate=0.0001))
    optimizer.add_hook(chainer.optimizer_hooks.Lasso(rate=0.0005))
    # optimizer.add_hook(chainer.optimizer_hooks.Lasso(rate=0.001))
    # optimizer.add_hook(chainer.optimizer_hooks.Lasso(rate=0.0001))
    ''' Updater '''
    print('init updater')
    train_iter = chainer.iterators.SerialIterator(train_dataset,
                                                  batch_size=args.batchsize)
    validation_iter = chainer.iterators.SerialIterator(validation_dataset,
                                                       batch_size=1,
                                                       repeat=False,
                                                       shuffle=False)
    updater = chainer.training.updaters.StandardUpdater(train_iter,
                                                        optimizer,
                                                        device=args.gpu)
    ''' Trainer '''
    current_datetime = datetime.now(
        pytz.timezone('Asia/Tokyo')).strftime('%Y%m%d_%H%M%S')
    save_dir = args.outdir + '_' + str(current_datetime)
    os.makedirs(save_dir, exist_ok=True)
    trainer = training.Trainer(updater,
                               stop_trigger=(args.epoch, 'epoch'),
                               out=save_dir)
    '''
    Extensions:
        Evaluator         : Evaluate the model with the validation dataset for each epoch
        ProgressBar       : print a progress bar and recent training status.
        ExponentialShift  : The typical use case is an exponential decay of the learning rate.
        dump_graph        : This extension dumps a computational graph.
        snapshot          : serializes the trainer object and saves it to the output directory
        snapshot_object   : serializes the given object and saves it to the output directory.
        LogReport         : output the accumulated results to a log file.
        PrintReport       : print the accumulated results.
        PlotReport        : output plots.
    '''

    evaluator = extensions.Evaluator(validation_iter, model, device=args.gpu)
    trainer.extend(evaluator, trigger=(1, 'epoch'))

    if args.optimizer == 'SGD':
        trainer.extend(extensions.ExponentialShift('lr', 0.1),
                       trigger=(50, 'epoch'))

    trigger = triggers.MinValueTrigger('validation/main/loss',
                                       trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model,
                                              filename='best_loss_model'),
                   trigger=trigger)

    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=(1, 'epoch'))

    # LogReport
    trainer.extend(extension=extensions.LogReport())

    # PrintReport
    trainer.extend(extension=extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'elapsed_time'
    ]))

    # PlotReport
    trainer.extend(extension=extensions.PlotReport(
        ['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
    # trainer.extend(
    #     extension=extensions.PlotReport(
    #         ['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.run()
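Note that --crop_size and --coordinate are passed around as strings such as '(640, 640)'; the actual parsing happens inside PreprocessedRegressionDataset, which is not shown here. A common way to turn such a string into a tuple, given only as an illustration, is ast.literal_eval:

import ast

crop_size = ast.literal_eval('(640, 640)')    # -> (640, 640)
coordinate = ast.literal_eval('(1840, 740)')  # -> (1840, 740)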
Example #11
def main():
    parser = argparse.ArgumentParser(description='training mnist')
    parser.add_argument('--gpu',
                        '-g',
                        default=-1,
                        type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--seed',
                        '-s',
                        type=int,
                        default=0,
                        help='Random seed')
    parser.add_argument('--n_fold',
                        '-nf',
                        type=int,
                        default=5,
                        help='n_fold cross validation')
    parser.add_argument('--fold', '-f', type=int, default=1)
    parser.add_argument('--out_dir_name',
                        '-dn',
                        type=str,
                        default=None,
                        help='Name of the output directory')
    parser.add_argument('--report_trigger',
                        '-rt',
                        type=str,
                        default='1e',
                        help='Interval for reporting (e.g. 100i, default: 1e)')
    parser.add_argument('--save_trigger',
                        '-st',
                        type=str,
                        default='1e',
                        help='Interval for saving the model '
                        '(e.g. 100i, default: 1e)')
    parser.add_argument('--load_model',
                        '-lm',
                        type=str,
                        default=None,
                        help='Path of the model object to load')
    parser.add_argument('--load_optimizer',
                        '-lo',
                        type=str,
                        default=None,
                        help='Path of the optimizer object to load')
    args = parser.parse_args()

    if args.out_dir_name is None:
        start_time = datetime.now()
        out_dir = Path('output/{}'.format(start_time.strftime('%Y%m%d_%H%M')))
    else:
        out_dir = Path('output/{}'.format(args.out_dir_name))

    random.seed(args.seed)
    np.random.seed(args.seed)
    cupy.random.seed(args.seed)
    chainer.config.cudnn_deterministic = True

    # model = ModifiedClassifier(SEResNeXt50())
    # model = ModifiedClassifier(SERes2Net50())
    model = ModifiedClassifier(SEResNeXt101())

    if args.load_model is not None:
        serializers.load_npz(args.load_model, model)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    optimizer = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(1e-4))
    if args.load_optimizer is not None:
        serializers.load_npz(args.load_optimizer, optimizer)

    n_fold = args.n_fold
    slices = [slice(i, None, n_fold) for i in range(n_fold)]
    fold = args.fold - 1

    # model1
    # augmentation = [
    #     ('Rotate', {'p': 0.8, 'limit': 5}),
    #     ('PadIfNeeded', {'p': 0.5, 'min_height': 28, 'min_width': 30}),
    #     ('PadIfNeeded', {'p': 0.5, 'min_height': 30, 'min_width': 28}),
    #     ('Resize', {'p': 1.0, 'height': 28, 'width': 28}),
    #     ('RandomScale', {'p': 1.0, 'scale_limit': 0.1}),
    #     ('PadIfNeeded', {'p': 1.0, 'min_height': 32, 'min_width': 32}),
    #     ('RandomCrop', {'p': 1.0, 'height': 28, 'width': 28}),
    #     ('Mixup', {'p': 0.5}),
    #     ('Cutout', {'p': 0.5, 'num_holes': 4, 'max_h_size': 4,
    #                 'max_w_size': 4}),
    # ]
    # resize = None

    # model2
    augmentation = [
        ('Rotate', {
            'p': 0.8,
            'limit': 5
        }),
        ('PadIfNeeded', {
            'p': 0.5,
            'min_height': 28,
            'min_width': 32
        }),
        ('PadIfNeeded', {
            'p': 0.5,
            'min_height': 32,
            'min_width': 28
        }),
        ('Resize', {
            'p': 1.0,
            'height': 32,
            'width': 32
        }),
        ('RandomScale', {
            'p': 1.0,
            'scale_limit': 0.1
        }),
        ('PadIfNeeded', {
            'p': 1.0,
            'min_height': 36,
            'min_width': 36
        }),
        ('RandomCrop', {
            'p': 1.0,
            'height': 32,
            'width': 32
        }),
        ('Mixup', {
            'p': 0.5
        }),
        ('Cutout', {
            'p': 0.5,
            'num_holes': 4,
            'max_h_size': 4,
            'max_w_size': 4
        }),
    ]
    resize = [('Resize', {'p': 1.0, 'height': 32, 'width': 32})]

    train_data = KMNIST(augmentation=augmentation,
                        drop_index=slices[fold],
                        pseudo_labeling=True)
    valid_data = KMNIST(augmentation=resize, index=slices[fold])

    train_iter = iterators.SerialIterator(train_data, args.batchsize)
    valid_iter = iterators.SerialIterator(valid_data,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    updater = StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = Trainer(updater, (args.epoch, 'epoch'), out=out_dir)

    report_trigger = (int(args.report_trigger[:-1]),
                      'iteration' if args.report_trigger[-1] == 'i' else 'epoch')
    trainer.extend(extensions.LogReport(trigger=report_trigger))
    trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu),
                   name='val',
                   trigger=report_trigger)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'main/accuracy', 'val/main/loss',
        'val/main/accuracy', 'elapsed_time'
    ]),
                   trigger=report_trigger)
    trainer.extend(
        extensions.PlotReport(['main/loss', 'val/main/loss'],
                              x_key=report_trigger[1],
                              marker='.',
                              file_name='loss.png',
                              trigger=report_trigger))
    trainer.extend(
        extensions.PlotReport(['main/accuracy', 'val/main/accuracy'],
                              x_key=report_trigger[1],
                              marker='.',
                              file_name='accuracy.png',
                              trigger=report_trigger))

    save_trigger = (int(args.save_trigger[:-1]),
                    'iteration' if args.save_trigger[-1] == 'i' else 'epoch')
    trainer.extend(extensions.snapshot_object(
        model,
        filename='model_{0}-{{.updater.{0}}}.npz'.format(save_trigger[1])),
                   trigger=save_trigger)
    trainer.extend(extensions.snapshot_object(
        optimizer,
        filename='optimizer_{0}-{{.updater.{0}}}.npz'.format(save_trigger[1])),
                   trigger=save_trigger)
    trainer.extend(extensions.ProgressBar())
    trainer.extend(CosineAnnealing(lr_max=0.1, lr_min=1e-6, T_0=20),
                   trigger=(1, 'epoch'))

    best_model_trigger = triggers.MaxValueTrigger('val/main/accuracy',
                                                  trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model,
                                              filename='best_model.npz'),
                   trigger=best_model_trigger)
    trainer.extend(extensions.snapshot_object(optimizer,
                                              filename='best_optimizer.npz'),
                   trigger=best_model_trigger)
    best_loss_model_trigger = triggers.MinValueTrigger('val/main/loss',
                                                       trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model,
                                              filename='best_loss_model.npz'),
                   trigger=best_loss_model_trigger)
    trainer.extend(extensions.snapshot_object(
        optimizer, filename='best_loss_optimizer.npz'),
                   trigger=best_loss_model_trigger)

    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir()

    # Write parameters text
    with open(out_dir / 'train_params.txt', 'w') as f:
        f.write('model: {}\n'.format(model.predictor.__class__.__name__))
        f.write('n_epoch: {}\n'.format(args.epoch))
        f.write('batch_size: {}\n'.format(args.batchsize))
        f.write('n_data_train: {}\n'.format(len(train_data)))
        f.write('n_data_val: {}\n'.format(len(valid_data)))
        f.write('seed: {}\n'.format(args.seed))
        f.write('n_fold: {}\n'.format(args.n_fold))
        f.write('fold: {}\n'.format(args.fold))
        f.write('augmentation: \n')
        for process, param in augmentation:
            f.write('  {}: {}\n'.format(process, param))

    trainer.run()
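The --report_trigger and --save_trigger options encode an interval plus a unit suffix ('100i' means every 100 iterations, '1e' means every epoch). The tuple-building expressions above could be factored into a small helper; the sketch below is not part of the original script.

def parse_trigger(spec):
    # '100i' -> (100, 'iteration'); '1e' -> (1, 'epoch')
    unit = 'iteration' if spec[-1] == 'i' else 'epoch'
    return int(spec[:-1]), unit

assert parse_trigger('1e') == (1, 'epoch')
assert parse_trigger('100i') == (100, 'iteration')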