Beispiel #1
0
def main():
    """Train a pairwise binary classifier from pre-split train/validation files.

    Reads the command-line options via ``parse_arguments()``, parses
    ``args.train_datafile`` and ``args.valid_datafile`` with
    ``CSVFileParserForPair``, builds the predictor and a ``Classifier``
    around it, and runs a trainer that stops on an early-stopping trigger
    monitoring ``validation/main/loss``. Snapshots, logs and plots go to
    ``output/<args.out>``; the trained classifier is pickled there at the end.

    Raises:
        ValueError: if no target label was given, or if
            ``args.exp_shift_strategy`` is not 1, 2 or 3.
    """
    # Parse the arguments.
    args = parse_arguments()

    # These switches are compared against the literal string 'False';
    # any other value turns the option on.
    augment = args.augment != 'False'
    multi_gpu = args.multi_gpu != 'False'

    if not args.label:
        raise ValueError('No target label was specified.')
    labels = args.label
    class_num = len(labels) if isinstance(labels, list) else 1

    def postprocess_label(label_list):
        """Convert the parsed labels to an int32 array."""
        return np.asarray(label_list, dtype=np.int32)

    # Preprocess both files with the preprocessor registered for the method.
    logging.info('Preprocess train dataset and test dataset...')
    preprocessor = preprocess_method_dict[args.method]()
    parser = CSVFileParserForPair(
        preprocessor,
        postprocess_label=postprocess_label,
        labels=labels,
        smiles_cols=['smiles_1', 'smiles_2'])
    train = parser.parse(args.train_datafile)['dataset']
    valid = parser.parse(args.valid_datafile)['dataset']

    if augment:
        logging.info('Utilizing data augmentation in train set')
        train = augment_dataset(train)

    num_train = train.get_datasets()[0].shape[0]
    num_valid = valid.get_datasets()[0].shape[0]
    logging.info('Train/test split: {}/{}'.format(num_train, num_valid))

    # Comma-separated hidden sizes, e.g. "32,16"; empty string means none.
    net_hidden_dims = (
        tuple(int(dim) for dim in args.net_hidden_dims.split(','))
        if len(args.net_hidden_dims) else ()
    )
    fp_attention = bool(args.fp_attention)
    update_attention = bool(args.update_attention)
    weight_tying = args.weight_tying != 'False'
    attention_tying = args.attention_tying != 'False'
    fp_batch_normalization = args.fp_bn == 'True'
    layer_aggregator = (
        args.layer_aggregator if args.layer_aggregator != '' else None)
    context = args.context != 'False'
    output_activation = (
        functions.relu if args.output_activation == 'relu' else None)

    predictor_kwargs = dict(
        method=args.method,
        fp_hidden_dim=args.fp_hidden_dim,
        fp_out_dim=args.fp_out_dim,
        conv_layers=args.conv_layers,
        concat_hidden=args.concat_hidden,
        layer_aggregator=layer_aggregator,
        fp_dropout_rate=args.fp_dropout_rate,
        fp_batch_normalization=fp_batch_normalization,
        net_hidden_dims=net_hidden_dims,
        class_num=class_num,
        sim_method=args.sim_method,
        fp_attention=fp_attention,
        # NOTE(review): keyword is spelled `weight_typing` here; confirm it
        # matches the parameter name declared by `set_up_predictor`.
        weight_typing=weight_tying,
        attention_tying=attention_tying,
        update_attention=update_attention,
        context=context,
        context_layers=args.context_layers,
        context_dropout=args.context_dropout,
        message_function=args.message_function,
        readout_function=args.readout_function,
        num_timesteps=args.num_timesteps,
        num_output_hidden_layers=args.num_output_hidden_layers,
        output_hidden_dim=args.output_hidden_dim,
        output_activation=output_activation,
        symmetric=args.symmetric,
    )
    predictor = set_up_predictor(**predictor_kwargs)

    train_iter = SerialIterator(train, args.batchsize)
    test_iter = SerialIterator(
        valid, args.batchsize, repeat=False, shuffle=False)

    classifier = Classifier(
        predictor,
        lossfun=F.sigmoid_cross_entropy,
        metrics_fun={'accuracy': F.binary_accuracy},
        device=args.gpu)

    # Set up the optimizer.
    optimizer = optimizers.Adam(
        alpha=args.learning_rate,
        weight_decay_rate=args.weight_decay_rate)
    optimizer.setup(classifier)

    # Optional regularization hooks; each is enabled by a positive value.
    max_norm = args.max_norm
    if max_norm > 0:
        optimizer.add_hook(
            chainer.optimizer.GradientClipping(threshold=max_norm))
    l2_rate = args.l2_rate
    if l2_rate > 0:
        optimizer.add_hook(chainer.optimizer.WeightDecay(rate=l2_rate))
    l1_rate = args.l1_rate
    if l1_rate > 0:
        optimizer.add_hook(chainer.optimizer.Lasso(rate=l1_rate))

    # Set up the updater.
    if not multi_gpu:
        logging.info('Using single GPU')
        updater = training.StandardUpdater(
            train_iter, optimizer, device=args.gpu, converter=concat_mols)
    else:
        logging.info('Using multiple GPUs')
        # Device ids are fixed to GPUs 0 and 1 in the multi-GPU case.
        updater = training.ParallelUpdater(
            train_iter, optimizer,
            devices={'main': 0, 'second': 1},
            converter=concat_mols)

    # Set up the trainer; training stops via the early-stopping trigger.
    logging.info('Training...')
    early_stop = triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss',
        patients=30,
        max_trigger=(500, 'epoch'))
    out = 'output/' + args.out
    trainer = training.Trainer(updater, stop_trigger=early_stop, out=out)

    trainer.extend(E.Evaluator(
        test_iter, classifier, device=args.gpu, converter=concat_mols))

    train_eval_iter = SerialIterator(
        train, args.batchsize, repeat=False, shuffle=False)

    # The extension name 'validation' is already taken by `Evaluator`, so
    # the metric evaluators report under 'train_*' / 'val_*' names instead.
    metric_evaluators = (
        (AccuracyEvaluator, 'train_acc', 'val_acc'),
        (ROCAUCEvaluator, 'train_roc', 'val_roc'),
        (PRCAUCEvaluator, 'train_prc', 'val_prc'),
        (F1Evaluator, 'train_f', 'val_f'),
    )
    for evaluator_cls, train_name, val_name in metric_evaluators:
        trainer.extend(evaluator_cls(
            train_eval_iter, classifier, eval_func=predictor,
            device=args.gpu, converter=concat_mols, name=train_name,
            pos_labels=1, ignore_labels=-1, raise_value_error=False))
        trainer.extend(evaluator_cls(
            test_iter, classifier, eval_func=predictor,
            device=args.gpu, converter=concat_mols, name=val_name,
            pos_labels=1, ignore_labels=-1))

    # Epochs at which Adam's `alpha` is shifted, keyed by strategy id.
    shift_schedules = {
        1: [10, 20, 30, 40, 50, 60],
        2: [5, 10, 15, 20, 25, 30],
        3: [5, 10, 15, 20, 25, 30, 40, 50, 60, 70],
    }
    if args.exp_shift_strategy not in shift_schedules:
        raise ValueError('No such strategy to adapt learning rate')
    trainer.extend(
        E.ExponentialShift('alpha', args.exp_shift_rate),
        trigger=triggers.ManualScheduleTrigger(
            shift_schedules[args.exp_shift_strategy], 'epoch'))

    # Record the learning rate at every iteration.
    trainer.extend(E.observe_lr(), trigger=(1, 'iteration'))

    entries = [
        'epoch',
        'main/loss',
        'train_acc/main/accuracy',
        'train_roc/main/roc_auc',
        'train_prc/main/prc_auc',
        'train_f/main/f1',
        'validation/main/loss',
        'val_acc/main/accuracy',
        'val_roc/main/roc_auc',
        'val_prc/main/prc_auc',
        'val_f/main/f1',
        'lr',
        'elapsed_time',
    ]
    trainer.extend(E.PrintReport(entries=entries))
    # Snapshot every 2 epochs (changed from 10 on Mar. 1 2019).
    trainer.extend(E.snapshot(), trigger=(2, 'epoch'))
    trainer.extend(E.LogReport())
    trainer.extend(E.ProgressBar())
    trainer.extend(E.PlotReport(
        ['main/loss', 'validation/main/loss'],
        'epoch', file_name='loss.png'))
    trainer.extend(E.PlotReport(
        ['train_acc/main/accuracy', 'val_acc/main/accuracy'],
        'epoch', file_name='accuracy.png'))

    if args.resume:
        resume_path = os.path.join(out, args.resume)
        logging.info(
            'Resume training according to snapshot in {}'.format(resume_path))
        chainer.serializers.load_npz(resume_path, trainer)

    trainer.run()

    # Save the trained classifier.
    model_path = os.path.join(out, args.model_filename)
    logging.info('Saving the trained models to {}...'.format(model_path))
    classifier.save_pickle(model_path, protocol=args.protocol)
Beispiel #2
0
def main():
    """Train a pairwise binary classifier on a randomly split dataset.

    Reads the command-line options via ``parse_arguments()``, parses
    ``args.datafile`` with ``CSVFileParserForPair``, splits it randomly into
    training and validation parts according to ``args.train_data_ratio`` and
    ``args.seed``, trains for ``args.epoch`` epochs with Adam, and pickles
    the trained classifier to ``args.out/args.model_filename``.

    Raises:
        ValueError: if no target label was given.
    """
    # Parse the arguments.
    args = parse_arguments()

    if args.label:
        labels = args.label
        class_num = len(labels) if isinstance(labels, list) else 1
    else:
        raise ValueError('No target label was specified.')

    def postprocess_label(label_list):
        """Convert the parsed labels to an int32 array."""
        label_arr = np.asarray(label_list, dtype=np.int32)
        return label_arr

    # Apply a preprocessor to the dataset.
    print('Preprocessing dataset...')
    preprocessor = preprocess_method_dict[args.method]()
    parser = CSVFileParserForPair(preprocessor,
                                  postprocess_label=postprocess_label,
                                  labels=labels,
                                  smiles_cols=['smiles_1', 'smiles_2'])
    dataset = parser.parse(args.datafile)['dataset']

    # Split the dataset into training and validation.
    train_data_size = int(len(dataset) * args.train_data_ratio)
    train, val = split_dataset_random(dataset, train_data_size, args.seed)

    # Set up the predictor.
    predictor = set_up_predictor(args.method, args.unit_num, args.conv_layers,
                                 class_num)

    # Set up the iterator.
    train_iter = SerialIterator(train, args.batchsize)
    val_iter = SerialIterator(val, args.batchsize, repeat=False, shuffle=False)

    # Set up the classifier.
    # The loss is a sigmoid cross entropy over binary targets, so the
    # matching metric is `F.binary_accuracy` (multi-class `F.accuracy`
    # expects class-index labels of a different shape). The dict key keeps
    # the metric reported under the '.../accuracy' entries printed below.
    metrics_fun = {'accuracy': F.binary_accuracy}
    classifier = Classifier(predictor,
                            lossfun=F.sigmoid_cross_entropy,
                            metrics_fun=metrics_fun,
                            device=args.gpu)

    # Set up the optimizer.
    optimizer = optimizers.Adam()
    optimizer.setup(classifier)

    # Set up the updater.
    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       device=args.gpu,
                                       converter=concat_mols)

    # Set up the trainer.
    print('Training...')
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(
        E.Evaluator(val_iter,
                    classifier,
                    device=args.gpu,
                    converter=concat_mols))

    train_eval_iter = SerialIterator(train,
                                     args.batchsize,
                                     repeat=False,
                                     shuffle=False)

    # The extension name 'validation' is already used by `Evaluator`, so the
    # AUC evaluators report under 'train_*' / 'val_*' names instead.
    trainer.extend(
        ROCAUCEvaluator(train_eval_iter,
                        classifier,
                        eval_func=predictor,
                        device=args.gpu,
                        converter=concat_mols,
                        name='train_roc',
                        pos_labels=1,
                        ignore_labels=-1,
                        raise_value_error=False))
    trainer.extend(
        ROCAUCEvaluator(val_iter,
                        classifier,
                        eval_func=predictor,
                        device=args.gpu,
                        converter=concat_mols,
                        name='val_roc',
                        pos_labels=1,
                        ignore_labels=-1))

    trainer.extend(
        PRCAUCEvaluator(train_eval_iter,
                        classifier,
                        eval_func=predictor,
                        device=args.gpu,
                        converter=concat_mols,
                        name='train_prc',
                        pos_labels=1,
                        ignore_labels=-1,
                        raise_value_error=False))
    trainer.extend(
        PRCAUCEvaluator(val_iter,
                        classifier,
                        eval_func=predictor,
                        device=args.gpu,
                        converter=concat_mols,
                        name='val_prc',
                        pos_labels=1,
                        ignore_labels=-1))

    # The ROC/PRC AUC values are computed above but deliberately left out of
    # the printed report:
    # 'train_roc/main/roc_auc', 'train_prc/main/prc_auc',
    # 'val_roc/main/roc_auc', 'val_prc/main/prc_auc'
    trainer.extend(
        E.PrintReport([
            'epoch',
            'main/loss',
            'main/accuracy',
            'validation/main/loss',
            'validation/main/accuracy',
            'elapsed_time'
        ]))
    # A single snapshot is taken at the final epoch.
    trainer.extend(E.snapshot(), trigger=(args.epoch, 'epoch'))
    trainer.extend(E.LogReport())
    trainer.extend(E.ProgressBar())
    trainer.run()

    # Save the trained classifier.
    model_path = os.path.join(args.out, args.model_filename)
    print('Saving the trained models to {}...'.format(model_path))
    classifier.save_pickle(model_path, protocol=args.protocol)
def main():
    """Train a pairwise binary classifier on precomputed molecule features.

    Reads the options via ``parse_arguments()`` (accessed here as a mapping),
    picks the dataset parser from ``args['feature']``, parses the train and
    validation files, and runs a trainer that stops on an early-stopping
    trigger monitoring ``validation/main/loss``. Outputs go to
    ``output/<args['out']>``; the trained classifier is pickled there.

    Raises:
        ValueError: if no target label was given, or if
            ``args['exp_shift_strategy']`` is not 1, 2 or 3.
    """
    # Parse the arguments.
    args = parse_arguments()
    if args['label']:
        labels = args['label']
        class_num = len(labels) if isinstance(labels, list) else 1
    else:
        raise ValueError('No target label was specified.')

    def postprocess_label(label_list):
        """Convert the parsed labels to an int32 array."""
        label_arr = np.asarray(label_list, dtype=np.int32)
        return label_arr

    # Apply a preprocessor to the dataset.
    logging.info('Preprocess train dataset and valid dataset...')
    # use `ggnn` for the time being
    preprocessor = preprocess_method_dict['ggnn']()

    # Select the parser for the requested feature type. This must be one
    # if/elif/else chain: with two separate `if` statements the 'molenc'
    # choice was always overwritten by the Mol2Vec fallback.
    feature = args['feature']
    if feature == 'molenc':
        parser_cls = MolAutoencoderParserForPair
    elif feature == 'ssp':
        parser_cls = SSPParserForPair
    else:
        parser_cls = Mol2VecParserForPair
    parser = parser_cls(preprocessor,
                        postprocess_label=postprocess_label,
                        labels=labels,
                        smiles_cols=['smiles_1', 'smiles_2'])
    train = parser.parse(args['train_datafile'])['dataset']
    valid = parser.parse(args['valid_datafile'])['dataset']

    if args['augment']:
        logging.info('Utilizing data augmentation in train set')
        train = augment_dataset(train)

    num_train = train.get_datasets()[0].shape[0]
    num_valid = valid.get_datasets()[0].shape[0]
    logging.info('Train/test split: {}/{}'.format(num_train, num_valid))

    # Comma-separated hidden sizes, e.g. "32,16"; empty string means none.
    if len(args['net_hidden_dims']):
        net_hidden_dims = tuple([
            int(net_hidden_dim)
            for net_hidden_dim in args['net_hidden_dims'].split(',')
        ])
    else:
        net_hidden_dims = ()

    predictor = set_up_predictor(fp_out_dim=args['fp_out_dim'],
                                 net_hidden_dims=net_hidden_dims,
                                 class_num=class_num,
                                 sim_method=args['sim_method'],
                                 symmetric=args['symmetric'])

    train_iter = SerialIterator(train, args['batchsize'])
    test_iter = SerialIterator(valid,
                               args['batchsize'],
                               repeat=False,
                               shuffle=False)

    metrics_fun = {'accuracy': F.binary_accuracy}
    classifier = Classifier(predictor,
                            lossfun=F.sigmoid_cross_entropy,
                            metrics_fun=metrics_fun,
                            device=args['gpu'])

    # Set up the optimizer.
    optimizer = optimizers.Adam(alpha=args['learning_rate'],
                                weight_decay_rate=args['weight_decay_rate'])
    optimizer.setup(classifier)
    # Optional regularization hooks; each is enabled by a positive value.
    if args['max_norm'] > 0:
        optimizer.add_hook(
            chainer.optimizer.GradientClipping(threshold=args['max_norm']))
    if args['l2_rate'] > 0:
        optimizer.add_hook(chainer.optimizer.WeightDecay(rate=args['l2_rate']))
    if args['l1_rate'] > 0:
        optimizer.add_hook(chainer.optimizer.Lasso(rate=args['l1_rate']))

    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       device=args['gpu'],
                                       converter=concat_mols)

    # Set up the trainer; training stops via the early-stopping trigger.
    logging.info('Training...')
    early_stop = triggers.EarlyStoppingTrigger(monitor='validation/main/loss',
                                               patients=10,
                                               max_trigger=(500, 'epoch'))
    out = 'output' + '/' + args['out']
    trainer = training.Trainer(updater, stop_trigger=early_stop, out=out)

    trainer.extend(
        E.Evaluator(test_iter,
                    classifier,
                    device=args['gpu'],
                    converter=concat_mols))

    train_eval_iter = SerialIterator(train,
                                     args['batchsize'],
                                     repeat=False,
                                     shuffle=False)

    # The extension name 'validation' is already used by `Evaluator`, so the
    # metric evaluators report under 'train_*' / 'val_*' names instead.
    # Registration order is the same as before: train then validation for
    # accuracy, ROC AUC, PRC AUC and F1.
    metric_evaluators = (
        (AccuracyEvaluator, 'train_acc', 'val_acc'),
        (ROCAUCEvaluator, 'train_roc', 'val_roc'),
        (PRCAUCEvaluator, 'train_prc', 'val_prc'),
        (F1Evaluator, 'train_f', 'val_f'),
    )
    for evaluator_cls, train_name, val_name in metric_evaluators:
        trainer.extend(
            evaluator_cls(train_eval_iter,
                          classifier,
                          eval_func=predictor,
                          device=args['gpu'],
                          converter=concat_mols,
                          name=train_name,
                          pos_labels=1,
                          ignore_labels=-1,
                          raise_value_error=False))
        trainer.extend(
            evaluator_cls(test_iter,
                          classifier,
                          eval_func=predictor,
                          device=args['gpu'],
                          converter=concat_mols,
                          name=val_name,
                          pos_labels=1,
                          ignore_labels=-1))

    # Shift Adam's `alpha` at the epochs given by the selected strategy.
    if args['exp_shift_strategy'] == 1:
        trainer.extend(E.ExponentialShift('alpha', args['exp_shift_rate']),
                       trigger=triggers.ManualScheduleTrigger(
                           [10, 20, 30, 40, 50, 60], 'epoch'))
    elif args['exp_shift_strategy'] == 2:
        trainer.extend(E.ExponentialShift('alpha', args['exp_shift_rate']),
                       trigger=triggers.ManualScheduleTrigger(
                           [5, 10, 15, 20, 25, 30], 'epoch'))
    elif args['exp_shift_strategy'] == 3:
        trainer.extend(E.ExponentialShift('alpha', args['exp_shift_rate']),
                       trigger=triggers.ManualScheduleTrigger(
                           [5, 10, 15, 20, 25, 30, 40, 50, 60, 70], 'epoch'))
    else:
        raise ValueError('No such strategy to adapt learning rate')
    # Record the learning rate at every iteration.
    trainer.extend(E.observe_lr(), trigger=(1, 'iteration'))

    entries = [
        'epoch',
        'main/loss',
        'train_acc/main/accuracy',
        'train_roc/main/roc_auc',
        'train_prc/main/prc_auc',
        'train_f/main/f1',
        'validation/main/loss',
        'val_acc/main/accuracy',
        'val_roc/main/roc_auc',
        'val_prc/main/prc_auc',
        'val_f/main/f1',
        'lr',
        'elapsed_time'
    ]
    trainer.extend(E.PrintReport(entries=entries))
    # Snapshot every 2 epochs (changed from 10 on Mar. 1 2019).
    trainer.extend(E.snapshot(), trigger=(2, 'epoch'))
    trainer.extend(E.LogReport())
    trainer.extend(E.ProgressBar())
    trainer.extend(
        E.PlotReport(['main/loss', 'validation/main/loss'],
                     'epoch',
                     file_name='loss.png'))
    trainer.extend(
        E.PlotReport(['train_acc/main/accuracy', 'val_acc/main/accuracy'],
                     'epoch',
                     file_name='accuracy.png'))

    if args['resume']:
        resume_path = os.path.join(out, args['resume'])
        logging.info(
            'Resume training according to snapshot in {}'.format(resume_path))
        chainer.serializers.load_npz(resume_path, trainer)

    trainer.run()

    # Save the trained classifier.
    model_path = os.path.join(out, args['model_filename'])
    logging.info('Saving the trained models to {}...'.format(model_path))
    classifier.save_pickle(model_path, protocol=args['protocol'])
Beispiel #4
0
def main():
    """Train a pairwise binary classifier on a randomly split dataset.

    Reads the command-line options via ``parse_arguments()``, parses
    ``args.datafile`` with ``CSVFileParserForPair``, splits it randomly into
    training and validation parts, and trains for ``args.epoch`` epochs with
    Adam while shifting ``alpha`` on a fixed epoch schedule. Accuracy,
    ROC AUC, PRC AUC and F1 are reported for both parts. The trained
    classifier is pickled to ``args.out/args.model_filename``.

    Raises:
        ValueError: if no target label was given.
    """
    # Parse the arguments.
    args = parse_arguments()

    if not args.label:
        raise ValueError('No target label was specified.')
    labels = args.label
    class_num = len(labels) if isinstance(labels, list) else 1

    def postprocess_label(label_list):
        """Convert the parsed labels to an int32 array."""
        return np.asarray(label_list, dtype=np.int32)

    # Apply a preprocessor to the dataset.
    print('Preprocessing dataset...')
    preprocessor = preprocess_method_dict[args.method]()
    parser = CSVFileParserForPair(
        preprocessor,
        postprocess_label=postprocess_label,
        labels=labels,
        smiles_cols=['smiles_1', 'smiles_2'])
    dataset = parser.parse(args.datafile)['dataset']

    # Split the dataset into training and validation.
    train_data_size = int(len(dataset) * args.train_data_ratio)
    train, val = split_dataset_random(dataset, train_data_size, args.seed)

    # Comma-separated hidden sizes, e.g. "32,16"; empty string means none.
    net_hidden_dims = (
        tuple(int(dim) for dim in args.net_hidden_dims.split(','))
        if len(args.net_hidden_dims) else ()
    )

    # Set up the predictor.
    predictor = set_up_predictor(
        method=args.method,
        fp_hidden_dim=args.fp_hidden_dim,
        fp_out_dim=args.fp_out_dim,
        conv_layers=args.conv_layers,
        concat_hidden=args.concat_hidden,
        fp_dropout_rate=args.fp_dropout_rate,
        net_hidden_dims=net_hidden_dims,
        class_num=class_num,
        sim_method=args.sim_method)

    # Set up the iterators.
    train_iter = SerialIterator(train, args.batchsize)
    val_iter = SerialIterator(val, args.batchsize, repeat=False, shuffle=False)

    classifier = Classifier(
        predictor,
        lossfun=F.sigmoid_cross_entropy,
        metrics_fun={'accuracy': F.binary_accuracy},
        device=args.gpu)

    # Set up the optimizer.
    optimizer = optimizers.Adam(
        alpha=args.learning_rate,
        weight_decay_rate=args.weight_decay_rate)
    optimizer.setup(classifier)

    # Set up the updater.
    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu, converter=concat_mols)

    # Set up the trainer.
    print('Training...')
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(E.Evaluator(
        val_iter, classifier, device=args.gpu, converter=concat_mols))

    train_eval_iter = SerialIterator(
        train, args.batchsize, repeat=False, shuffle=False)

    # The extension name 'validation' is already taken by `Evaluator`, so
    # the metric evaluators report under 'train_*' / 'val_*' names instead.
    # Precision and recall evaluators are intentionally not registered.
    metric_evaluators = (
        (AccuracyEvaluator, 'train_acc', 'val_acc'),
        (ROCAUCEvaluator, 'train_roc', 'val_roc'),
        (PRCAUCEvaluator, 'train_prc', 'val_prc'),
        (F1Evaluator, 'train_f', 'val_f'),
    )
    for evaluator_cls, train_name, val_name in metric_evaluators:
        trainer.extend(evaluator_cls(
            train_eval_iter, classifier, eval_func=predictor,
            device=args.gpu, converter=concat_mols, name=train_name,
            pos_labels=1, ignore_labels=-1, raise_value_error=False))
        trainer.extend(evaluator_cls(
            val_iter, classifier, eval_func=predictor,
            device=args.gpu, converter=concat_mols, name=val_name,
            pos_labels=1, ignore_labels=-1))

    # Shift Adam's `alpha` at epochs 10, 20, 30, 40 and 50.
    shift_epochs = [10, 20, 30, 40, 50]
    trainer.extend(
        E.ExponentialShift('alpha', args.exp_shift_rate),
        trigger=triggers.ManualScheduleTrigger(shift_epochs, 'epoch'))
    # Record the learning rate at every iteration.
    trainer.extend(E.observe_lr(), trigger=(1, 'iteration'))

    entries = [
        'epoch',
        'main/loss',
        'train_acc/main/accuracy',
        'train_roc/main/roc_auc',
        'train_prc/main/prc_auc',
        'train_f/main/f1',
        'validation/main/loss',
        'val_acc/main/accuracy',
        'val_roc/main/roc_auc',
        'val_prc/main/prc_auc',
        'val_f/main/f1',
        'lr',
        'elapsed_time',
    ]
    trainer.extend(E.PrintReport(entries=entries))
    # A single snapshot is taken at the final epoch.
    trainer.extend(E.snapshot(), trigger=(args.epoch, 'epoch'))
    trainer.extend(E.LogReport())
    trainer.extend(E.ProgressBar())

    if args.resume:
        resume_path = os.path.join(args.out, args.resume)
        logging.info(
            'Resume training according to snapshot in {}'.format(resume_path))
        chainer.serializers.load_npz(resume_path, trainer)

    trainer.run()

    # Save the trained classifier.
    model_path = os.path.join(args.out, args.model_filename)
    print('Saving the trained models to {}...'.format(model_path))
    classifier.save_pickle(model_path, protocol=args.protocol)