Ejemplo n.º 1
0
def main(args):
    """Train an ABCNN/BCNN sentence-pair model and evaluate it WikiQA-style.

    A timestamped output directory is created under ``../result/``.  The
    parsed arguments are written there as ``settings.json`` and echoed to
    stderr.  A Chainer ``Trainer`` is then configured and run.

    Args:
        args: Parsed command-line namespace.  Fields read here include
            ``data``, ``vocab``, ``test``, ``max_length``, ``dim``,
            ``model_type``, ``single_attention_mat``, ``glove``/``glove_path``,
            ``word2vec``/``word2vec_path``, ``gpu``, ``lr``, ``decay``,
            ``batchsize``, ``use_test_data`` and ``epoch``.
    """
    start_time = datetime.now().strftime('%Y%m%d_%H_%M_%S')
    dest = "../result/" + start_time
    os.makedirs(dest)
    abs_dest = os.path.abspath(dest)
    # Serialize once; the same text goes to the settings file and to stderr.
    settings = json.dumps(vars(args), sort_keys=True, indent=4)
    with open(os.path.join(dest, "settings.json"), "w") as fo:
        fo.write(settings)
    print(settings, file=sys.stderr)

    # load data
    data_processor = DataProcessor(args.data, args.vocab, args.test, args.max_length)
    data_processor.prepare_dataset()
    # must run before max_x1s_len / max_x2s_len are read below
    data_processor.compute_max_length()
    train_data = data_processor.train_data
    dev_data = data_processor.dev_data
    test_data = data_processor.test_data

    # create model
    vocab = data_processor.vocab
    model_type = args.model_type
    # ABCNN1 / ABCNN3 get a second input channel (presumably the attention
    # feature map, as in the ABCNN paper); every other variant uses one.
    input_channel = 2 if model_type in ('ABCNN1', 'ABCNN3') else 1
    # output_channel is fixed at 50 (the original author noted that ABCNN
    # apparently uses a fixed output size of 50).
    cnn = ABCNN(n_vocab=len(vocab), embed_dim=args.dim,
                input_channel=input_channel, output_channel=50,
                x1s_len=data_processor.max_x1s_len,
                x2s_len=data_processor.max_x2s_len,
                model_type=model_type,
                single_attention_mat=args.single_attention_mat)
    model = Classifier(cnn, lossfun=sigmoid_cross_entropy,
                       accfun=binary_accuracy)
    if args.glove:
        cnn.load_glove_embeddings(args.glove_path, vocab)
    if args.word2vec:
        cnn.load_word2vec_embeddings(args.word2vec_path, vocab)
    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        model.to_gpu()
    cnn.set_pad_embedding_to_zero(vocab)

    # setup optimizer
    optimizer = O.AdaGrad(args.lr)
    optimizer.setup(model)
    # do not use weight decay for embeddings
    decay_params = {name: 1 for name, _ in model.namedparams()
                    if "embed" not in name}
    optimizer.add_hook(SelectiveWeightDecay(
        rate=args.decay, decay_params=decay_params))

    train_iter = chainer.iterators.SerialIterator(train_data, args.batchsize)
    # NOTE(review): prints a private attribute of SerialIterator; looks like
    # leftover debugging output — consider removing.
    print(train_iter._order)

    # Non-repeating pass over the training set, used only for evaluation.
    dev_train_iter = chainer.iterators.SerialIterator(
        train_data, args.batchsize, repeat=False)
    if args.use_test_data:
        dev_iter = DevIterator(test_data, data_processor.n_test)
    else:
        dev_iter = DevIterator(dev_data, data_processor.n_dev)

    updater = ABCNNUpdater(train_iter, optimizer, converter=BCNN.util.concat_examples, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=abs_dest)

    # setup evaluation on a copy of the predictor switched to test mode
    eval_predictor = model.copy().predictor
    eval_predictor.train = False
    iters = {"train": dev_train_iter, "dev": dev_iter}
    trainer.extend(WikiQAEvaluator(
        iters, eval_predictor, converter=BCNN.util.concat_examples, device=args.gpu))

    # extensions
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss', 'validation/main/map', 'validation/main/mrr', 'validation/main/svm_map', 'validation/main/svm_mrr']))
    trainer.extend(extensions.ProgressBar(update_interval=10))
    # take a snapshot whenever the model reaches a new best MAP on the dev set
    trainer.extend(extensions.snapshot_object(
        model, 'model_epoch_{.updater.epoch}',
        trigger=chainer.training.triggers.MaxValueTrigger('validation/main/map')))
    # trainer.extend(extensions.ExponentialShift("lr", 0.5, optimizer=optimizer),
    #                trigger=chainer.training.triggers.MaxValueTrigger("validation/main/map"))
    trainer.run()
Ejemplo n.º 2
0
def main():
    """Train a graph-convolution classifier on the Tox21 dataset.

    Parses the command line and builds the chosen predictor.  It trains with
    sigmoid cross-entropy, over all Tox21 labels unless ``--label`` selects
    one.  After training it writes ``config.json`` and the pickled classifier
    into ``--out``.

    Raises:
        ValueError: If the iterator type or evaluation mode is unknown, or if
            the ``balanced`` iterator is requested for multi-label training.
    """
    # Supported preprocessing/network list
    method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn']
    label_names = D.get_tox21_label_names()
    iterator_type = ['serial', 'balanced']

    parser = argparse.ArgumentParser(
        description='Multitask Learning with Tox21.')
    parser.add_argument('--method',
                        '-m',
                        type=str,
                        choices=method_list,
                        default='nfp',
                        help='graph convolution model to use '
                        'as a predictor.')
    parser.add_argument('--label',
                        '-l',
                        type=str,
                        choices=label_names,
                        default='',
                        help='target label for logistic '
                        'regression. Use all labels if this option '
                        'is not specified.')
    parser.add_argument('--iterator-type',
                        type=str,
                        choices=iterator_type,
                        default='serial',
                        help='iterator type. If `balanced` '
                        'is specified, data is sampled to take same number of '
                        'positive/negative labels during training.')
    parser.add_argument('--eval-mode',
                        type=int,
                        default=1,
                        help='Evaluation mode. '
                        '0: only binary_accuracy is calculated. '
                        '1: binary_accuracy and ROC-AUC score are calculated.')
    parser.add_argument('--conv-layers',
                        '-c',
                        type=int,
                        default=4,
                        help='number of convolution layers')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID to use. Negative value indicates '
                        'not to use GPU and to run the code in CPU.')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        default='result',
                        help='path to output directory')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=10,
                        help='number of epochs')
    parser.add_argument('--unit-num',
                        '-u',
                        type=int,
                        default=16,
                        help='number of units in one layer of the model')
    parser.add_argument('--resume',
                        '-r',
                        type=str,
                        default='',
                        help='path to a trainer snapshot')
    parser.add_argument('--frequency',
                        '-f',
                        type=int,
                        default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--protocol',
                        type=int,
                        default=2,
                        help='protocol version for pickle')
    parser.add_argument('--model-filename',
                        type=str,
                        default='classifier.pkl',
                        help='file name for pickled model')
    parser.add_argument('--num-data',
                        type=int,
                        default=-1,
                        help='Number of data to be parsed from parser. '
                        '-1 indicates to parse all data.')
    args = parser.parse_args()

    method = args.method
    if args.label:
        labels = args.label
        class_num = len(labels) if isinstance(labels, list) else 1
    else:
        # no --label: multitask over every Tox21 label
        labels = None
        class_num = len(label_names)

    # Dataset preparation
    train, val, _ = data.load_dataset(method, labels, num_data=args.num_data)

    # Network
    predictor_ = predictor.build_predictor(method, args.unit_num,
                                           args.conv_layers, class_num)

    iterator_type = args.iterator_type
    if iterator_type == 'serial':
        train_iter = I.SerialIterator(train, args.batchsize)
    elif iterator_type == 'balanced':
        if class_num > 1:
            raise ValueError('BalancedSerialIterator can be used with only one '
                             'label classification, please specify label to '
                             'be predicted by --label option.')
        # The last feature column is passed as the per-sample label; -1
        # entries are ignored when balancing.
        train_iter = BalancedSerialIterator(train,
                                            args.batchsize,
                                            train.features[:, -1],
                                            ignore_labels=-1)
        train_iter.show_label_stats()
    else:
        raise ValueError('Invalid iterator type {}'.format(iterator_type))
    val_iter = I.SerialIterator(val,
                                args.batchsize,
                                repeat=False,
                                shuffle=False)

    classifier = Classifier(predictor_,
                            lossfun=F.sigmoid_cross_entropy,
                            metrics_fun=F.binary_accuracy,
                            device=args.gpu)

    optimizer = O.Adam()
    optimizer.setup(classifier)

    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       device=args.gpu,
                                       converter=concat_mols)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(
        E.Evaluator(val_iter,
                    classifier,
                    device=args.gpu,
                    converter=concat_mols))
    trainer.extend(E.LogReport())

    eval_mode = args.eval_mode
    if eval_mode == 0:
        trainer.extend(
            E.PrintReport([
                'epoch', 'main/loss', 'main/accuracy', 'validation/main/loss',
                'validation/main/accuracy', 'elapsed_time'
            ]))
    elif eval_mode == 1:
        train_eval_iter = I.SerialIterator(train,
                                           args.batchsize,
                                           repeat=False,
                                           shuffle=False)
        trainer.extend(
            ROCAUCEvaluator(train_eval_iter,
                            classifier,
                            eval_func=predictor_,
                            device=args.gpu,
                            converter=concat_mols,
                            name='train',
                            pos_labels=1,
                            ignore_labels=-1,
                            raise_value_error=False))
        # extension name='validation' is already used by `Evaluator`,
        # instead extension name `val` is used.
        trainer.extend(
            ROCAUCEvaluator(val_iter,
                            classifier,
                            eval_func=predictor_,
                            device=args.gpu,
                            converter=concat_mols,
                            name='val',
                            pos_labels=1,
                            ignore_labels=-1))
        trainer.extend(
            E.PrintReport([
                'epoch', 'main/loss', 'main/accuracy', 'train/main/roc_auc',
                'validation/main/loss', 'validation/main/accuracy',
                'val/main/roc_auc', 'elapsed_time'
            ]))
    else:
        raise ValueError('Invalid eval_mode {}'.format(eval_mode))
    trainer.extend(E.ProgressBar(update_interval=10))
    # default (-1): snapshot only once, at the final epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(E.snapshot(), trigger=(frequency, 'epoch'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()

    config = {
        'method': args.method,
        'conv_layers': args.conv_layers,
        'unit_num': args.unit_num,
        'labels': args.label
    }

    with open(os.path.join(args.out, 'config.json'), 'w') as o:
        json.dump(config, o)

    classifier.save_pickle(os.path.join(args.out, args.model_filename),
                           protocol=args.protocol)
Ejemplo n.º 3
0
def main():
    """Train Mask R-CNN on SBD instance segmentation (VOC classes).

    The script runs either as a single process (``--gpu`` selects the device;
    without it the model stays on CPU) or, with ``--multi-node``, as
    data-parallel training through ChainerMN.  Logs and snapshots are written
    to ``<here>/logs/<timestamp>``.

    Raises:
        ValueError: If the pooling function or base model is unsupported.
        RuntimeError: If ``extensions.PlotReport`` is unavailable on the
            reporting rank.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model',
                        '-m',
                        choices=['vgg16', 'resnet50', 'resnet101'],
                        default='resnet50',
                        help='base model')
    parser.add_argument('--pooling-func',
                        '-p',
                        choices=['pooling', 'align', 'resize'],
                        default='align',
                        help='pooling function')
    parser.add_argument('--gpu', '-g', type=int, help='gpu id')
    parser.add_argument('--multi-node',
                        '-n',
                        action='store_true',
                        help='use multi node')
    parser.add_argument('--roi-size',
                        '-r',
                        type=int,
                        default=7,
                        help='roi size')
    args = parser.parse_args()

    if args.multi_node:
        import chainermn
        comm = chainermn.create_communicator('hierarchical')
        # one process per GPU: the rank within the node picks the device
        device = comm.intra_rank

        args.n_node = comm.inter_size
        args.n_gpu = comm.size
        chainer.cuda.get_device_from_id(device).use()
    else:
        args.n_node = 1
        args.n_gpu = 1
        chainer.cuda.get_device_from_id(args.gpu).use()
        device = args.gpu

    args.seed = 0
    now = datetime.datetime.now()
    args.timestamp = now.isoformat()
    args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S'))

    # 0.00125 * 8 = 0.01  in original
    args.batch_size = 1 * args.n_gpu
    args.lr = 0.00125 * args.batch_size
    args.weight_decay = 0.0001

    # (180e3 * 8) / len(coco_trainval)
    # NOTE(review): the schedule is derived from the COCO trainval size even
    # though the dataset below is SBD — confirm this is intended.
    args.max_epoch = (180e3 * 8) / 118287
    # lr / 10 at 120k iteration with
    # 160k iteration * 16 batchsize in original
    args.step_size = [(120e3 / 180e3) * args.max_epoch,
                      (160e3 / 180e3) * args.max_epoch]

    random.seed(args.seed)
    np.random.seed(args.seed)

    args.dataset = 'voc'
    train_data = mrcnn.datasets.SBDInstanceSegmentationDataset('train')
    test_data = mrcnn.datasets.SBDInstanceSegmentationDataset('val')
    fg_class_names = train_data.class_names

    if args.pooling_func == 'align':
        pooling_func = mrcnn.functions.roi_align_2d
    elif args.pooling_func == 'pooling':
        pooling_func = chainer.functions.roi_pooling_2d
    elif args.pooling_func == 'resize':
        pooling_func = mrcnn.functions.crop_and_resize
    else:
        raise ValueError(
            'Unsupported pooling function: {}'.format(args.pooling_func))

    if args.model == 'vgg16':
        mask_rcnn = mrcnn.models.MaskRCNNVGG16(
            n_fg_class=len(fg_class_names),
            pretrained_model='imagenet',
            pooling_func=pooling_func,
            roi_size=args.roi_size,
        )
    elif args.model in ['resnet50', 'resnet101']:
        # Slice off the prefix; str.lstrip would strip a *set* of characters.
        n_layers = int(args.model[len('resnet'):])
        mask_rcnn = mrcnn.models.MaskRCNNResNet(
            n_layers=n_layers,
            n_fg_class=len(fg_class_names),
            pretrained_model='imagenet',
            pooling_func=pooling_func,
            roi_size=args.roi_size,
        )
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))
    mask_rcnn.use_preset('evaluate')
    model = mrcnn.models.MaskRCNNTrainChain(mask_rcnn)
    # --gpu has no default, so args.gpu may be None; comparing None >= 0
    # raises TypeError on Python 3.
    if args.multi_node or (args.gpu is not None and args.gpu >= 0):
        model.to_gpu()

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
    if args.multi_node:
        optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=args.weight_decay))

    if args.model in ['resnet50', 'resnet101']:
        model.mask_rcnn.extractor.mode = 'res3+'
        # freeze the early ResNet stages
        mask_rcnn.extractor.conv1.disable_update()
        mask_rcnn.extractor.bn1.disable_update()
        mask_rcnn.extractor.res2.disable_update()

    train_data = chainer.datasets.TransformDataset(
        train_data, mrcnn.datasets.MaskRCNNTransform(mask_rcnn))
    test_data = chainer.datasets.TransformDataset(
        test_data, mrcnn.datasets.MaskRCNNTransform(mask_rcnn, train=False))
    if args.multi_node:
        # only rank 0 provides the data; scatter_dataset distributes it
        if comm.rank != 0:
            train_data = None
            test_data = None
        train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)
        test_data = chainermn.scatter_dataset(test_data, comm)

    train_iter = chainer.iterators.MultiprocessIterator(train_data,
                                                        batch_size=1,
                                                        n_prefetch=4,
                                                        shared_mem=10**8)
    test_iter = chainer.iterators.MultiprocessIterator(test_data,
                                                       batch_size=1,
                                                       n_prefetch=4,
                                                       shared_mem=10**8,
                                                       repeat=False,
                                                       shuffle=False)

    updater = chainer.training.updater.StandardUpdater(
        train_iter,
        optimizer,
        device=device,
        converter=mrcnn.datasets.concat_examples)

    trainer = training.Trainer(updater, (args.max_epoch, 'epoch'),
                               out=args.out)

    trainer.extend(extensions.ExponentialShift('lr', 0.1),
                   trigger=training.triggers.ManualScheduleTrigger(
                       args.step_size, 'epoch'))

    eval_interval = 1, 'epoch'
    log_interval = 20, 'iteration'
    plot_interval = 0.1, 'epoch'
    print_interval = 20, 'iteration'

    evaluator = mrcnn.extensions.InstanceSegmentationVOCEvaluator(
        test_iter,
        model.mask_rcnn,
        device=device,
        use_07_metric=True,
        label_names=fg_class_names)
    if args.multi_node:
        evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator, trigger=eval_interval)

    # Reporting, snapshots and plots happen on a single process only.
    if not args.multi_node or comm.rank == 0:
        trainer.extend(extensions.snapshot_object(model.mask_rcnn,
                                                  'snapshot_model.npz'),
                       trigger=training.triggers.MaxValueTrigger(
                           'validation/main/map', eval_interval))
        args.git_hash = mrcnn.utils.git_hash()
        args.hostname = socket.gethostname()
        trainer.extend(fcn.extensions.ParamsReport(args.__dict__))
        trainer.extend(mrcnn.extensions.InstanceSegmentationVisReport(
            test_iter, model.mask_rcnn, label_names=fg_class_names),
                       trigger=eval_interval)
        trainer.extend(chainer.training.extensions.observe_lr(),
                       trigger=log_interval)
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.PrintReport([
            'iteration', 'epoch', 'elapsed_time', 'lr', 'main/loss',
            'main/roi_loc_loss', 'main/roi_cls_loss', 'main/roi_mask_loss',
            'main/rpn_loc_loss', 'main/rpn_cls_loss', 'validation/main/map'
        ]),
                       trigger=print_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        # plot (explicit check instead of assert, which -O would strip)
        if not extensions.PlotReport.available():
            raise RuntimeError('extensions.PlotReport is not available')
        trainer.extend(
            extensions.PlotReport([
                'main/loss', 'main/roi_loc_loss', 'main/roi_cls_loss',
                'main/roi_mask_loss', 'main/rpn_loc_loss', 'main/rpn_cls_loss'
            ],
                                  file_name='loss.png',
                                  trigger=plot_interval),
            trigger=plot_interval,
        )
        trainer.extend(
            extensions.PlotReport(['validation/main/map'],
                                  file_name='accuracy.png',
                                  trigger=plot_interval),
            trigger=eval_interval,
        )

        trainer.extend(extensions.dump_graph('main/loss'))

    trainer.run()
Ejemplo n.º 4
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]['input'][i]['shape'][-1])
        for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    for i in range(args.num_encs):
        logging.info('stream{}: input dims : {}'.format(i + 1, idim_list[i]))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    if (args.enc_init is not None
            or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim_list[0] if args.num_encs == 1 else idim_list,
                            odim, args)
    assert isinstance(model, ASRInterface)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim,
                 vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                'batch size is automatically increased (%d -> %d)' %
                (args.batch_size, args.batch_size * args.ngpu))
            #args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup.")

    # set torch device
    #device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    device = torch.device("cuda", args.local_rank)
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    #model = model.to(device=device, dtype=dtype)
    model = model.to(device, torch.float32)
    model = DistributedDataParallel(model,
                                    device_ids=[args.local_rank],
                                    output_device=args.local_rank)
    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        #converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
        converter = CustomConverter(dtype=dtype)
    else:
        converter = CustomConverterMulEnc([i[0] for i in model.subsample_list],
                                          dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_dataset = TransformDataset(train,
                                     lambda data: converter([load_tr(data)]))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_iter = {
        'main':
        ChainerDataLoader(dataset=train_dataset,
                          batch_size=1,
                          collate_fn=lambda x: x[0],
                          sampler=train_sampler)
    }
    valid_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            valid, lambda data: converter([load_cv(data)])),
                          batch_size=1,
                          shuffle=False,
                          collate_fn=lambda x: x[0],
                          num_workers=args.n_iter_processes)
    }

    # train_iter = {'main': ChainerDataLoader(
    #     dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
    #     batch_size=1, num_workers=args.n_iter_processes,
    #     shuffle=not use_sortagrad, collate_fn=lambda x: x[0])}
    # valid_iter = {'main': ChainerDataLoader(
    #     dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
    #     batch_size=1, shuffle=False, collate_fn=lambda x: x[0],
    #     num_workers=args.n_iter_processes)}

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex,
                            local_rank=args.local_rank)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)
    if args.local_rank == 0:
        # Evaluate the model with the test dataset for each epoch
        if args.save_interval_iters > 0:
            trainer.extend(CustomEvaluator(model, valid_iter, reporter, device,
                                           args.ngpu),
                           trigger=(args.save_interval_iters, 'iteration'))
        else:
            trainer.extend(
                CustomEvaluator(model, valid_iter, reporter, device,
                                args.ngpu))

        # Save attention weight each epoch
        if args.num_save_attention > 0 and args.mtlalpha != 1.0:
            data = sorted(list(valid_json.items())[:args.num_save_attention],
                          key=lambda x: int(x[1]['input'][0]['shape'][1]),
                          reverse=True)
            if hasattr(model, "module"):
                att_vis_fn = model.module.calculate_all_attentions
                plot_class = model.module.attention_plot_class
            else:
                att_vis_fn = model.calculate_all_attentions
                plot_class = model.attention_plot_class
            att_reporter = plot_class(att_vis_fn,
                                      data,
                                      args.outdir + "/att_ws",
                                      converter=converter,
                                      transform=load_cv,
                                      device=device)
            trainer.extend(att_reporter, trigger=(1, 'epoch'))
        else:
            att_reporter = None

        # Make a plot for training and validation values
        if args.num_encs > 1:
            report_keys_loss_ctc = [
                'main/loss_ctc{}'.format(i + 1) for i in range(model.num_encs)
            ] + [
                'validation/main/loss_ctc{}'.format(i + 1)
                for i in range(model.num_encs)
            ]
            report_keys_cer_ctc = [
                'main/cer_ctc{}'.format(i + 1) for i in range(model.num_encs)
            ] + [
                'validation/main/cer_ctc{}'.format(i + 1)
                for i in range(model.num_encs)
            ]
        trainer.extend(
            extensions.PlotReport([
                'main/loss', 'validation/main/loss', 'main/loss_ctc',
                'validation/main/loss_ctc', 'main/loss_att',
                'validation/main/loss_att'
            ] + ([] if args.num_encs == 1 else report_keys_loss_ctc),
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                  'epoch',
                                  file_name='acc.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cer_ctc', 'validation/main/cer_ctc'] +
                ([] if args.num_encs == 1 else report_keys_loss_ctc),
                'epoch',
                file_name='cer.png'))

        # Save best models
        trainer.extend(
            snapshot_object(model, 'model.loss.best'),
            trigger=training.triggers.MinValueTrigger('validation/main/loss'))
        if mtl_mode != 'ctc':
            trainer.extend(snapshot_object(model, 'model.acc.best'),
                           trigger=training.triggers.MaxValueTrigger(
                               'validation/main/acc'))

        # save snapshot which contains model and optimizer states
        if args.save_interval_iters > 0:
            trainer.extend(
                torch_snapshot(filename='snapshot.iter.{.updater.iteration}'),
                trigger=(args.save_interval_iters, 'iteration'))
        else:
            trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
    if args.local_rank == 0:
        # Write a log of evaluation statistics for each epoch
        trainer.extend(
            extensions.LogReport(trigger=(args.report_interval_iters,
                                          'iteration')))
        report_keys = [
            'epoch', 'iteration', 'main/loss', 'main/loss_ctc',
            'main/loss_att', 'validation/main/loss',
            'validation/main/loss_ctc', 'validation/main/loss_att', 'main/acc',
            'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc',
            'elapsed_time'
        ] + ([] if args.num_encs == 1 else report_keys_cer_ctc +
             report_keys_loss_ctc)
        if args.opt == 'adadelta':
            trainer.extend(extensions.observe_value(
                'eps', lambda trainer: trainer.updater.get_optimizer('main').
                param_groups[0]["eps"]),
                           trigger=(args.report_interval_iters, 'iteration'))
            report_keys.append('eps')
        if args.report_cer:
            report_keys.append('validation/main/cer')
        if args.report_wer:
            report_keys.append('validation/main/wer')
        trainer.extend(extensions.PrintReport(report_keys),
                       trigger=(args.report_interval_iters, 'iteration'))

        trainer.extend(
            extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)
    if args.local_rank == 0:
        if args.tensorboard_dir is not None and args.tensorboard_dir != "":
            trainer.extend(TensorboardLogger(
                SummaryWriter(args.tensorboard_dir), att_reporter),
                           trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Ejemplo n.º 5
0
def train(args):
    """Train with the given args.

    Loads the train/valid JSON manifests, builds the model named by
    ``args.model_module``, configures the optimizer, iterators, updater and
    trainer extensions (evaluation, snapshots, plots, logging, optional
    AdaDelta epsilon decay, optional TensorBoard), then runs training and
    writes all outputs under ``args.outdir``.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.atype`` or ``args.opt`` is not supported,
            or if a bin/frame batch-count mode is requested with multiple GPUs.

    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # check attention type
    if args.atype not in ['noatt', 'dot', 'location']:
        raise NotImplementedError('chainer supports only noatt, dot, and location attention.')

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)),
                           indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # The parallel updater expects the main device under 'main' and the
        # remaining ones under 'sub_<id>'.
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        # alpha starts at 0; it is driven by the VaswaniRule extension below.
        optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data (the validation manifest was already loaded above)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json, args.batch_size,
                              args.maxlen_in, args.maxlen_out, args.minibatches,
                              min_batch_size=1,  # ngpu <= 1 here
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [ToggleableShufflingMultiprocessIterator(
                TransformDataset(train, load_tr),
                batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20,
                shuffle=not use_sortagrad)]
        else:
            train_iters = [ToggleableShufflingSerialIterator(
                TransformDataset(train, load_tr),
                batch_size=1, shuffle=not use_sortagrad)]

        # set up updater
        updater = CustomUpdater(
            train_iters[0], optimizer, converter=converter, device=gpu_id, accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError("--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu")
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset: round-robin split of utterances across GPUs
            train_json_subset = {k: v for i, (k, v) in enumerate(train_json.items())
                                 if i % ngpu == gid}
            # make minibatch list (variable length).  shortest_first must follow
            # SortaGrad here as well, otherwise disabling shuffling below would
            # keep the default batch order instead of shortest-first.
            train_subsets += [make_batchset(train_json_subset, args.batch_size,
                                            args.maxlen_in, args.maxlen_out, args.minibatches,
                                            shortest_first=use_sortagrad)]

        # each subset must have same length for MultiprocessParallelUpdater
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                # pad shorter subsets by repeating their leading batches
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [ToggleableShufflingMultiprocessIterator(
                TransformDataset(train_subsets[gid], load_tr),
                batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20,
                shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)]
        else:
            train_iters = [ToggleableShufflingSerialIterator(
                TransformDataset(train_subsets[gid], load_tr),
                batch_size=1, shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)]

        # set up updater
        updater = CustomParallelUpdater(
            train_iters, optimizer, converter=converter, devices=devices)

    # Set up a trainer
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    if use_sortagrad:
        # Re-enable shuffling once the SortaGrad phase is over.
        trainer.extend(ShufflingEnabler(train_iters),
                       trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha', d=args.adim, warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr), trigger=(1, 'iteration'))
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(BaseEvaluator(
        valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, transform=load_cv, device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(extensions.snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(extensions.snapshot_object(model, 'model.acc.best'),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the validation criterion gets worse,
    # restore the best model and decay AdaDelta's eps.
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(args.report_interval_iters, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
            trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Ejemplo n.º 6
0
def train(args):
    """Train a language model with the given args.

    Loads the train/validation label data, builds the LM class resolved from
    ``args.model_module`` / ``args.backend``, trains it with a BPTT updater,
    keeps per-epoch snapshots plus a symlink to the best model, and finally
    (if ``args.test_label`` is set) reports perplexity of the best model on
    the test set.

    :param Namespace args: The program arguments
    :raises NotImplementedError: If ``args.opt`` is neither ``'sgd'`` nor ``'adam'``.
    """

    def _oov_rate(n_oovs, n_tokens):
        # Percentage of OOV tokens; 0.0 for an empty set instead of dividing by zero.
        return n_oovs / n_tokens * 100 if n_tokens > 0 else 0.0

    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class,
                      LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(args.valid_label,
                                                 args.char_list_dict,
                                                 args.dump_hdf5_path)
    train, n_train_tokens, n_train_oovs = load_dataset(args.train_label,
                                                       args.char_list_dict,
                                                       args.dump_hdf5_path)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%' %
                 _oov_rate(n_train_oovs, n_train_tokens))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%' %
                 _oov_rate(n_val_oovs, n_val_tokens))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators; the batch is scaled by the number of GPUs.
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size > args.batchsize:
        logging.info(
            f'batch size is automatically increased ({args.batchsize} -> {batch_size})'
        )
    # <eos> is used as both start and end symbol of each sentence.
    train_iter = ParallelSentenceIterator(train,
                                          batch_size,
                                          max_length=args.maxlen,
                                          sos=eos,
                                          eos=eos,
                                          shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val,
                                        batch_size,
                                        max_length=args.maxlen,
                                        sos=eos,
                                        eos=eos,
                                        repeat=False)
    logging.info('#iterations per epoch = ' +
                 str(len(train_iter.batch_indices)))
    logging.info('#total iterations = ' +
                 str(args.epoch * len(train_iter.batch_indices)))
    # Prepare an RNNLM model; apex opt levels (O0-O3) keep float32 weights here.
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # Without this branch `optimizer` would be unbound below.
        raise NotImplementedError(f'args.opt={args.opt}')

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # Make the torch model/optimizer look like chainer objects so the chainer
    # trainer's reporting and serialization work on them.
    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(train_iter,
                          model,
                          optimizer,
                          gpu_id,
                          gradclip=args.gradclip,
                          use_apex=use_apex)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(postprocess=compute_perplexity,
                             trigger=(args.report_interval_iters,
                                      'iteration')))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'perplexity', 'val_perplexity',
        'elapsed_time'
    ]),
                   trigger=(args.report_interval_iters, 'iteration'))
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(snapshot_object(model, 'rnnlm.model.{.updater.epoch}'))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(
        MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch,
                     'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer),
                       trigger=(args.report_interval_iters, 'iteration'))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        torch_load(args.outdir + '/rnnlm.model.best', model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%' %
                     _oov_rate(n_test_oovs, n_test_tokens))
        test_iter = ParallelSentenceIterator(test,
                                             batch_size,
                                             max_length=args.maxlen,
                                             sos=eos,
                                             eos=eos,
                                             repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        compute_perplexity(result)  # adds 'perplexity' to result in place
        logging.info(f"test perplexity: {result['perplexity']}")
Ejemplo n.º 7
0
def main(params):
    """Train a (XML-)CNN multi-label classifier for one depth and predict the test set.

    Logs the run configuration, trains with early stopping on the validation
    loss, moves the best parameters into ``./CNN/PARAMS``, reloads them in
    ``'test-predict'`` mode, and writes per-class sigmoid probabilities to
    ``CNN/RESULT/probability_<current_depth>.csv``.

    Args:
        params (dict): Run settings. Keys read here include ``gpu``, ``unit``,
            ``batchsize``, ``epoch``, ``output_dimensions``,
            ``embedding_dimensions``, ``embedding_weight``, ``current_depth``
            (used in file names, so presumably a str), ``model_type``,
            ``out_channels``, ``fine_tuning``, ``upper_depth``,
            ``learning_categories`` and ``input_data`` (a dict holding
            ``x_trn``, ``y_trn``, ``x_val``, ``y_val`` and ``x_tst``).

    Returns:
        numpy.ndarray: ``int8`` array of shape
        ``(len(x_tst), output_dimensions)`` with ``select_function`` applied
        to the sigmoid outputs of each test batch.
    """
    use_gpu = params["gpu"] >= 0

    config_lines = [
        '# gpu: {}'.format(params["gpu"]),
        '# unit: {}'.format(params["unit"]),
        '# batch-size: {}'.format(params["batchsize"]),
        '# epoch: {}'.format(params["epoch"]),
        '# number of category: {}'.format(params["output_dimensions"]),
        '# embedding dimension: {}'.format(params["embedding_dimensions"]),
        '# current layer: {}'.format(params["current_depth"]),
        '# model-type: {}'.format(params["model_type"]),
    ]
    print("")
    for line in config_lines:
        print(line)
    print('')

    with open('./CNN/LOG/configuration_' + params["current_depth"] + '.txt',
              'w') as f:
        for line in config_lines:
            f.write(line + "\n")
        f.write("\n")

    embedding_weight = params["embedding_weight"]
    embedding_dimensions = params["embedding_dimensions"]
    input_data = params["input_data"]
    x_train = input_data['x_trn']
    x_val = input_data['x_val']
    y_train = input_data['y_trn']
    y_val = input_data['y_val']

    cnn_params = {
        "cudnn": USE_CUDNN,
        "out_channels": params["out_channels"],
        "row_dim": embedding_dimensions,
        "batch_size": params["batchsize"],
        "hidden_dim": params["unit"],
        "n_classes": params["output_dimensions"],
        "embedding_weight": embedding_weight,
    }
    if params["fine_tuning"] == 0:
        cnn_params['mode'] = 'scratch'
    elif params["fine_tuning"] == 1:
        cnn_params['mode'] = 'fine-tuning'
        cnn_params['load_param_node_name'] = params['upper_depth']

    if params["model_type"] == "XML-CNN":
        model = xml_cnn_model.CNN(**cnn_params)
    else:
        model = cnn_model.CNN(**cnn_params)

    if use_gpu:
        chainer.cuda.get_device_from_id(params["gpu"]).use()
        model.to_gpu()

    # Learning CNN by training and validation data
    # =========================================================

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train = tuple_dataset.TupleDataset(x_train, y_train)
    val = tuple_dataset.TupleDataset(x_val, y_val)

    train_iter = chainer.iterators.SerialIterator(train,
                                                  params["batchsize"],
                                                  repeat=True,
                                                  shuffle=False)
    val_iter = chainer.iterators.SerialIterator(val,
                                                params["batchsize"],
                                                repeat=False,
                                                shuffle=False)

    # Early stopping monitors the validation loss (validation/main/loss)
    # =========================================================
    stop_trigger = training.triggers.EarlyStoppingTrigger(
        monitor='validation/main/loss', max_trigger=(params["epoch"], 'epoch'))

    updater = MyUpdater(train_iter,
                        optimizer,
                        params["output_dimensions"],
                        device=params["gpu"])
    trainer = training.Trainer(updater, stop_trigger, out='./CNN/')

    trainer.extend(
        MyEvaluator(val_iter,
                    model,
                    class_dim=params["output_dimensions"],
                    device=params["gpu"]))
    trainer.extend(extensions.dump_graph('main/loss'))

    # Keep only the parameters with the lowest validation loss.
    trainer.extend(extensions.snapshot_object(
        model, 'parameters_for_multi_label_model_' + params["current_depth"] +
        '.npz'),
                   trigger=training.triggers.MinValueTrigger(
                       'validation/main/loss', trigger=(1, 'epoch')))

    trainer.extend(
        extensions.LogReport(log_name='LOG/log_' + params["current_depth"] +
                             ".txt",
                             trigger=(1, 'epoch')))

    trainer.extend(
        extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
    trainer.extend(extensions.ProgressBar())

    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch',
                              file_name='LOG/loss_' + params["current_depth"] +
                              '.png'))

    trainer.run()

    filename = 'parameters_for_multi_label_model_' + params[
        "current_depth"] + '.npz'
    src = './CNN/'
    dst = './CNN/PARAMS'
    shutil.move(os.path.join(src, filename), os.path.join(dst, filename))

    # Prediction process for test data.
    # =========================================================
    print("-" * 50)
    print("Testing...")

    x_tst = input_data['x_tst']
    n_eval = len(x_tst)

    cnn_params['mode'] = 'test-predict'
    cnn_params['load_param_node_name'] = params["current_depth"]

    if params["model_type"] == "XML-CNN":
        model = xml_cnn_model.CNN(**cnn_params)
    else:
        model = cnn_model.CNN(**cnn_params)

    # Only move to the GPU when one was requested; previously this crashed on CPU.
    if use_gpu:
        model.to_gpu()
    output = np.zeros([n_eval, params["output_dimensions"]], dtype=np.int8)
    output_probability_file_name = "CNN/RESULT/probability_" + params[
        "current_depth"] + ".csv"

    test_batch_size = params["batchsize"]
    # Open the probability file once: header first, then one block per batch.
    with open(output_probability_file_name, 'w') as prob_file:
        prob_file.write(','.join(params["learning_categories"]) + "\n")
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            for i in tqdm(six.moves.range(0, n_eval, test_batch_size),
                          desc="Predict Test loop"):
                x_batch = x_tst[i:i + test_batch_size]
                if use_gpu:
                    x_batch = chainer.cuda.to_gpu(x_batch)
                net_output = F.sigmoid(model(chainer.Variable(x_batch)))
                output[i:i + test_batch_size] = select_function(net_output.data)
                probabilities = chainer.cuda.to_cpu(net_output.data)
                # Zero out negligible probabilities to keep the CSV compact.
                probabilities[probabilities < 0.001] = 0
                np.savetxt(prob_file, probabilities, fmt='%.4g', delimiter=",")
    return output
Ejemplo n.º 8
0
def main(options):
    """Train a CNN classifier with settings taken from ``options``.

    Args:
        options: mapping of configuration values. Keys read here are
            ``gpu`` (device id, negative for CPU), ``path_dataset``,
            ``path_vectors`` (embedding file loaded into the CNN),
            ``epochs``, ``batchsize``, ``embed_dim``,
            ``freeze_embeddings`` and ``distance_embed_dim``.
    """
    # Load the config params.
    gpu = options['gpu']
    data_path = options['path_dataset']
    embeddings_path = options['path_vectors']
    n_epoch = options['epochs']
    batch_size = options['batchsize']
    embed_dim = options['embed_dim']
    freeze = options['freeze_embeddings']
    distance_embed_dim = options['distance_embed_dim']

    # Load the data.
    data_processor = DataProcessor(data_path)
    data_processor.prepare_dataset()
    train_data = data_processor.train_data
    test_data = data_processor.test_data

    vocab = data_processor.vocab
    cnn = CNN(n_vocab=len(vocab),
              input_channel=1,
              output_channel=100,
              n_label=19,
              embed_dim=embed_dim,
              position_dims=distance_embed_dim,
              freeze=freeze)
    cnn.load_embeddings(embeddings_path, vocab)
    model = L.Classifier(cnn)

    # Use the GPU if requested. Passing the id explicitly puts the parameters
    # on the same device the updater/evaluator copy batches to (device=gpu);
    # a bare ``to_gpu()`` uses whichever device happens to be current.
    if gpu >= 0:
        model.to_gpu(gpu)

    # Set up the optimizer.
    optimizer = O.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train_data, batch_size)
    # Single ordered pass over the test set per evaluation.
    test_iter = chainer.iterators.SerialIterator(test_data,
                                                 batch_size,
                                                 repeat=False,
                                                 shuffle=False)

    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       converter=convert.concat_examples,
                                       device=gpu)
    trainer = training.Trainer(updater, (n_epoch, 'epoch'))

    # Evaluation runs on a copy of the classifier whose predictor has
    # ``train`` switched off; the copy shares parameters with ``model``.
    # NOTE(review): the flag only has an effect if the CNN reads
    # ``self.train`` (Chainer v1 style) -- confirm against the CNN class.
    test_model = model.copy()
    test_model.predictor.train = False
    trainer.extend(
        extensions.Evaluator(test_iter,
                             test_model,
                             device=gpu,
                             converter=convert.concat_examples))

    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy'
        ]))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()
Ejemplo n.º 9
0
    def run(self,
            embeddings,
            dataset,
            path_output='/tmp/text_classification/'):
        """Train and evaluate a text classifier on ``dataset``.

        Samples come from an ``adapter.py`` module in the dataset directory
        (``Adapter().read()`` must return ``(train, test, _)``) when that file
        exists; otherwise from ``Dataset(path)``, whose texts are tokenized
        with ``word_tokenize_txt``. Tokens are mapped to ids with
        ``embeddings.vocabulary.dic_words_ids``.

        Args:
            embeddings: object exposing ``matrix`` (its ``shape[1]`` becomes
                the hidden size), ``vocabulary.dic_words_ids`` and
                ``metadata``.
            dataset: object whose ``path`` attribute is the dataset directory.
            path_output: directory for the trainer log, best-model snapshot,
                ``vocab.json`` and ``args.json``; created if missing.

        Returns:
            A one-element list holding a dict with ``experiment_setup``,
            ``log`` (the parsed trainer log) and ``result`` (the best
            validation accuracy seen over all epochs).

        Raises:
            ValueError: if ``self.model`` is not ``'rnn'``, ``'cnn'`` or
                ``'bow'``.
        """
        self.out = path_output
        self.unit = embeddings.matrix.shape[1]

        if not os.path.isdir(path_output):
            os.makedirs(path_output)

        # TODO: move dataset loading to protonn ds management. The chainer
        # text_datasets loaders (dbpedia, imdb.*, TREC, stsa.*, custrev, mpqa,
        # rt-polarity, subj) were previously sketched here as an alternative.

        # TODO: make sure dataset module support adapter.py
        path_dataset = dataset.path
        print(path_dataset)
        path_adapter = os.path.join(path_dataset, "adapter.py")
        if os.path.isfile(path_adapter):
            # NOTE: this executes arbitrary Python from the dataset directory;
            # only point it at trusted datasets.
            spec = importlib.util.spec_from_file_location(
                "ds_adapter", path_adapter)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            adapter = module.Adapter()
            train, test, _ = adapter.read()
        else:
            print("loading though DS")
            ds = Dataset(path_dataset)
            train = [(word_tokenize_txt(i), j) for i, j in ds.get_train()]
            test = [(word_tokenize_txt(i), j) for i, j in ds.get_test()]

        # Both loading paths share the same token -> id conversion.
        vocab = embeddings.vocabulary.dic_words_ids
        train = nlp_utils.transform_to_array(train, vocab)
        test = nlp_utils.transform_to_array(test, vocab)

        print('# cnt train samples: {}'.format(len(train)))
        print('# cnt test  samples: {}'.format(len(test)))
        print('# size vocab: {}'.format(len(vocab)))
        # Number of distinct labels in the training set.
        # NOTE(review): assumes labels are 0..n_class-1 -- confirm.
        n_class = len({int(sample[1]) for sample in train})
        print('# cnt classes: {}'.format(n_class))

        train_iter = chainer.iterators.SerialIterator(train, self.batchsize)
        test_iter = chainer.iterators.SerialIterator(test,
                                                     self.batchsize,
                                                     repeat=False,
                                                     shuffle=False)

        # Set up a model.
        if self.model == 'rnn':
            Encoder = nets.RNNEncoder
        elif self.model == 'cnn':
            Encoder = nets.CNNEncoder
        elif self.model == 'bow':
            Encoder = nets.BOWMLPEncoder
        else:
            # Without this, an unknown name surfaced as an UnboundLocalError.
            raise ValueError(
                "unknown model {!r}; expected 'rnn', 'cnn' or 'bow'".format(
                    self.model))
        encoder = Encoder(n_layers=self.layer,
                          n_vocab=len(vocab),
                          n_units=self.unit,
                          dropout=self.dropout,
                          wv=embeddings.matrix)
        model = nets.TextClassifier(encoder, n_class)
        if self.gpu >= 0:
            # Make the specified GPU current, then copy the model to it.
            chainer.backends.cuda.get_device_from_id(self.gpu).use()
            model.to_gpu()

        # Set up an optimizer.
        optimizer = chainer.optimizers.Adam()
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(1e-4))

        # Set up a trainer.
        updater = training.StandardUpdater(train_iter,
                                           optimizer,
                                           converter=nlp_utils.convert_seq,
                                           device=self.gpu)
        trainer = training.Trainer(updater, (self.epoch, 'epoch'),
                                   out=self.out)

        # Evaluate the model with the test dataset for each epoch.
        trainer.extend(
            extensions.Evaluator(test_iter,
                                 model,
                                 converter=nlp_utils.convert_seq,
                                 device=self.gpu))

        # Snapshot the model whenever validation accuracy reaches a new max.
        record_trigger = training.triggers.MaxValueTrigger(
            'validation/main/accuracy', (1, 'epoch'))
        trainer.extend(extensions.snapshot_object(model, 'best_model.npz'),
                       trigger=record_trigger)

        # Write a log of evaluation statistics for each epoch.
        trainer.extend(extensions.LogReport())
        trainer.extend(
            extensions.PrintReport([
                'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
                'validation/main/accuracy', 'elapsed_time'
            ]))

        # Print a progress bar to stdout.
        trainer.extend(extensions.ProgressBar())

        # Save the vocabulary and the model's settings. ``self.out`` was
        # created at the top of this method.
        vocab_path = os.path.join(self.out, 'vocab.json')
        with open(vocab_path, 'w') as f:
            json.dump(vocab, f)
        model_path = os.path.join(self.out, 'best_model.npz')
        # NOTE: this aliases the instance dict, so every key added to
        # ``experiment_setup`` below also becomes an attribute of ``self``,
        # is written to args.json, and is shared with the returned result.
        experiment_setup = self.__dict__
        # TODO: move all this to the parent class
        experiment_setup['task'] = "text classification"
        experiment_setup['vocab_path'] = vocab_path
        experiment_setup['model_path'] = model_path
        experiment_setup['n_class'] = n_class
        experiment_setup['datetime'] = self.current_datetime
        with open(os.path.join(self.out, 'args.json'), 'w') as f:
            json.dump(self.__dict__, f)

        # Run the training.
        trainer.run()

        result = {}
        result['experiment_setup'] = experiment_setup
        result['experiment_setup']['default_measurement'] = 'accuracy'
        result['experiment_setup']['dataset'] = os.path.basename(
            os.path.normpath(path_dataset))
        result['experiment_setup']['method'] = self.model
        result['experiment_setup']['embeddings'] = embeddings.metadata
        result['log'] = load_json(os.path.join(self.out, 'log'))

        # TODO: old version was returning last test value, make a footnote
        # Report the best validation accuracy over all logged epochs.
        accuracy = max(entry["validation/main/accuracy"]
                       for entry in result['log'])
        result['result'] = {"accuracy": accuracy}
        return [result]
Ejemplo n.º 10
0
def main():
    """Train a convnet on ILSVRC2012 image-label lists with a Chainer Trainer.

    The architecture is chosen with ``--arch`` from ``archs``. Images are
    loaded through ``PreprocessedDataset`` using a mean file, with
    ``MultiprocessIterator`` workers. Validation, snapshots and logging fire
    every 100000 / 1000 iterations (every 10 iterations with ``--test``).
    """
    archs = {
        'alex': alex.Alex,
        'googlenet': googlenet.GoogLeNet,
        'googlenetbn': googlenetbn.GoogLeNetBN,
        'nin': nin.NIN
    }

    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to training image-label list file')
    parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--arch',
                        '-a',
                        choices=archs.keys(),
                        default='nin',
                        help='Convnet architecture')
    parser.add_argument('--batchsize',
                        '-B',
                        type=int,
                        default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch',
                        '-E',
                        type=int,
                        default=10,
                        help='Number of epochs to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob',
                        '-j',
                        type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean',
                        '-m',
                        default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Output directory')
    parser.add_argument('--root',
                        '-R',
                        default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize',
                        '-b',
                        type=int,
                        default=250,
                        help='Validation minibatch size')
    # store_true already defaults to False, so no set_defaults is needed.
    parser.add_argument('--test', action='store_true')
    args = parser.parse_args()

    # Initialize the model to train.
    model = archs[args.arch]()
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make the GPU current
        model.to_gpu()

    # Load the datasets and mean file.
    mean = np.load(args.mean)
    train = PreprocessedDataset(args.train, args.root, mean, model.insize)
    val = PreprocessedDataset(args.val, args.root, mean, model.insize, False)
    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Set up an optimizer.
    optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)

    # Set up a trainer.
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    # Short intervals in --test mode for quick smoke runs.
    val_interval = (10 if args.test else 100000), 'iteration'
    log_interval = (10 if args.test else 1000), 'iteration'

    # Copy the chain with shared parameters to flip 'train' flag only in test.
    eval_model = model.copy()
    eval_model.train = False

    trainer.extend(extensions.Evaluator(val_iter, eval_model, device=args.gpu),
                   trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        model, 'model_iter_{.updater.iteration}'),
                   trigger=val_interval)
    # Be careful to pass the interval directly to LogReport
    # (it determines when to emit log rather than when to read observations).
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.PrintReport([
        'epoch',
        'iteration',
        'main/loss',
        'validation/main/loss',
        'main/accuracy',
        'validation/main/accuracy',
    ]),
                   trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Ejemplo n.º 11
0
def main():
    """Train the pix2pix encoder/decoder/discriminator on facade images.

    Command-line options choose the image and contour directories, how many
    samples to load, batch size, epoch count, GPU and output directory. The
    loaded dataset is split at random into ``data_num - 25`` training samples
    and the remainder, which ``FacadeUpdater`` receives as its ``'test'``
    iterator. Every ``--snapshot_interval`` epochs the trainer, encoder and
    decoder are snapshotted and ``out_image`` renders samples.
    """
    parser = argparse.ArgumentParser(
        description='chainer implementation of pix2pix')
    add = parser.add_argument
    add('--batchsize', '-b', type=int, default=1,
        help='Number of images in each mini-batch')
    add('--epoch', '-e', type=int, default=100,
        help='Number of sweeps over the dataset to train')
    add('--gpu', '-g', type=int, default=-1,
        help='GPU ID (negative value indicates CPU)')
    add('--dataset', '-i', default='../../Image',
        help='Directory of image files.')
    add('--dataset_contour', '-c', default='../../Image_Contour',
        help='Directory of contour image files')
    add('--data_num', '-n', type=int, default=400,
        help='number of data to use')
    add('--out', '-o', default='result',
        help='Directory to output the result')
    add('--resume', '-r', default='',
        help='Resume the training from snapshot')
    add('--seed', type=int, default=0, help='Random seed')
    add('--snapshot_interval', type=int, default=20,
        help='Interval epoch of snapshot')
    add('--display_interval', type=int, default=1,
        help='Interval of displaying log to console')
    args = parser.parse_args()

    banner = (
        'GPU: {}'.format(args.gpu),
        '# Minibatch-size: {}'.format(args.batchsize),
        '# epoch: {}'.format(args.epoch),
        '# snapshot interval: {}'.format(args.snapshot_interval),
        '',
    )
    for line in banner:
        print(line)

    # The three networks, in the order FacadeUpdater expects them.
    encoder = Encoder(in_ch=3)
    decoder = Decoder(out_ch=3)
    discriminator = Discriminator(in_ch=3, out_ch=3)
    networks = (encoder, decoder, discriminator)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # select the target GPU
        for net in networks:
            net.to_gpu()

    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer bound to ``model`` with weight decay."""
        opt = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        opt.setup(model)
        opt.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
        return opt

    optimizers = {
        key: make_optimizer(net)
        for key, net in zip(('enc', 'dec', 'dis'), networks)
    }

    facades = FacadeDataset(args.dataset, args.dataset_contour, args.data_num)
    train_d, test_d = datasets.split_dataset_random(facades,
                                                    args.data_num - 25)
    #train_iter = chainer.iterators.MultiprocessIterator(train_d, args.batchsize, n_processes=14)
    #test_iter = chainer.iterators.MultiprocessIterator(test_d, args.batchsize, n_processes=14)
    iterators = {
        'main': chainer.iterators.SerialIterator(train_d, args.batchsize,
                                                 shuffle=True),
        'test': chainer.iterators.SerialIterator(test_d, args.batchsize,
                                                 shuffle=False),
    }

    updater = FacadeUpdater(models=networks,
                            iterator=iterators,
                            optimizer=optimizers,
                            device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    snapshot_interval = (args.snapshot_interval, 'epoch')
    display_interval = (args.display_interval, 'epoch')

    # Registration order is kept stable: extension names (and therefore the
    # keys inside trainer snapshots) depend on it.
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    model_snapshots = (
        (encoder, 'enc_iter_{.updater.iteration}.npz'),
        (decoder, 'dec_iter_{.updater.iteration}.npz'),
    )
    for net, pattern in model_snapshots:
        trainer.extend(extensions.snapshot_object(net, pattern),
                       trigger=snapshot_interval)
    # trainer.extend(extensions.snapshot_object(
    #     dis, 'dis_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(
        extensions.PrintReport(
            ['epoch', 'iteration', 'enc/loss', 'dec/loss', 'dis/loss']),
        trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(
        extensions.PlotReport(['enc/loss', 'dec/loss', 'dis/loss'],
                              'epoch',
                              file_name='loss.png'),
        trigger=display_interval)
    trainer.extend(
        out_image(updater, encoder, decoder, 5, 5, args.seed, args.out),
        trigger=snapshot_interval)

    if args.resume:
        # Restore the full trainer state from a snapshot before training.
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Ejemplo n.º 12
0
                                              'gen_epoch_{.updater.epoch}.npz',
                                              savefun=save_npz),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(dis,
                                              'dis_epoch_{.updater.epoch}.npz',
                                              savefun=save_npz),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport([
        'epoch',
        'iteration',
        'gen/loss',
        'dis/loss',
        'elapsed_time',
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(out_generated_image(gen, dis, 5, 5, opt.seed, out),
                   trigger=display_interval)
    trainer.extend(
        extensions.PlotReport(['gen/loss', 'dis/loss'],
                              x_key='epoch',
                              file_name='loss_{0}_{1}.jpg'.format(
                                  opt.number, opt.seed),
                              grid=False))
    trainer.extend(extensions.dump_graph("gen/loss", out_name="gen.dot"))
    trainer.extend(extensions.dump_graph("dis/loss", out_name="dis.dot"))

    # Run the training
    trainer.run()
Ejemplo n.º 13
0
def main():
    """Train ``MyCNN`` on MNIST with a Chainer ``Trainer``.

    Command-line options control batch size, epoch count, GPU id, snapshot
    frequency (``-1`` means a single snapshot at the final epoch) and resuming
    from a trainer snapshot. Logs, snapshots and, when
    ``PlotReport.available()``, loss/accuracy plots are written to the
    trainer's output directory.
    """
    parser = argparse.ArgumentParser(description='Chainer-Tutorial: CNN')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=128,
                        help='Number of samples in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=100,
                        help='Number of times to train on data set')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID: -1 indicates CPU')
    parser.add_argument('--frequency',
                        '-f',
                        type=int,
                        default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    args = parser.parse_args()

    # Load mnist data.
    # http://docs.chainer.org/en/latest/reference/datasets.html
    # Setting ndim=3, returns samples with shape (1, 28, 28)
    train, test = chainer.datasets.get_mnist(ndim=3)

    # Define iterators; the test iterator makes one ordered pass per eval.
    train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 args.batch_size,
                                                 repeat=False,
                                                 shuffle=False)

    # Initialize model: Loss function defaults to softmax_cross_entropy.
    model = L.Classifier(MyCNN(10))

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    optimizer = chainer.optimizers.RMSprop(lr=0.001, alpha=0.9)
    optimizer.setup(model)

    # Set up trainer.
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'))

    # Evaluate the model at end of each epoch.
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    trainer.extend(extensions.dump_graph('main/loss'))

    # Helper functions (extensions) to monitor progress on stdout.
    report_params = [
        'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
        'validation/main/accuracy', 'elapsed_time'
    ]
    # A single LogReport: registering a second one only added a duplicate
    # extension (``LogReport_1``) writing the same log again.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(report_params))
    trainer.extend(extensions.ProgressBar())

    # Take a snapshot every ``frequency`` epochs (only at the end by default).
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Save two plot images to the result dir.
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run trainer.
    trainer.run()
Ejemplo n.º 14
0
def main():
    """Train a DCGAN with ChainerMN data-parallel workers.

    Rank 0 loads the training images (CIFAR-10 when ``--dataset`` is empty,
    otherwise the ``.png``/``.jpg`` files in that directory), and
    ``chainermn.scatter_dataset`` distributes them to all workers. Generator
    and discriminator use multi-node Adam optimizers. Only rank 0 registers
    the logging, snapshot and sample-image extensions.
    """
    import sys  # only needed for the error exit below

    parser = argparse.ArgumentParser(description='ChainerMN example: DCGAN')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=50,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator',
                        type=str,
                        default='hierarchical',
                        help='Type of communicator')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true', help='Use GPU')
    parser.add_argument('--dataset',
                        '-i',
                        default='',
                        help='Directory of image files.  Default is cifar-10.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--gen_model',
                        '-r',
                        default='',
                        help='Use pre-trained generator for training')
    parser.add_argument('--dis_model',
                        '-d',
                        default='',
                        help='Use pre-trained discriminator for training')
    parser.add_argument('--n_hidden',
                        '-n',
                        type=int,
                        default=100,
                        help='Number of hidden units (z)')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    if args.gpu:
        if args.communicator == 'naive':
            # Report the configuration error on stderr; ``sys.exit`` works
            # even where the site-provided ``exit`` builtin is unavailable.
            print("Error: 'naive' communicator does not support GPU.\n",
                  file=sys.stderr)
            sys.exit(-1)
        comm = chainermn.create_communicator(args.communicator)
        # One GPU per process on each node, selected by intra-node rank.
        device = comm.intra_rank
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        comm = chainermn.create_communicator('naive')
        device = -1

    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        if args.gpu:
            print('Using GPUs')
        print('Using {} communicator'.format(args.communicator))
        print('Num hidden unit: {}'.format(args.n_hidden))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    # Set up a neural network to train.
    gen = Generator(n_hidden=args.n_hidden)
    dis = Discriminator()

    if device >= 0:
        # Make a specified GPU current, then copy the models to it.
        chainer.cuda.get_device_from_id(device).use()
        gen.to_gpu()
        dis.to_gpu()

    def make_optimizer(model, comm, alpha=0.0002, beta1=0.5):
        """Wrap Adam in a multi-node optimizer bound to ``model``."""
        optimizer = chainermn.create_multi_node_optimizer(
            chainer.optimizers.Adam(alpha=alpha, beta1=beta1), comm)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(gen, comm)
    opt_dis = make_optimizer(dis, comm)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        if args.dataset == '':
            # Load the CIFAR10 dataset if args.dataset is not specified.
            train, _ = chainer.datasets.get_cifar10(withlabel=False,
                                                    scale=255.)
        else:
            # Match on the file extension (a substring test also picked up
            # names like "x.png.bak") and sort for a deterministic order.
            image_files = sorted(
                f for f in os.listdir(args.dataset)
                if f.endswith(('.png', '.jpg')))
            print('{} contains {} image files'.format(args.dataset,
                                                      len(image_files)))
            train = chainer.datasets\
                .ImageDataset(paths=image_files, root=args.dataset)
    else:
        train = None

    train = chainermn.scatter_dataset(train, comm)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Set up a trainer.
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer={
                               'gen': opt_gen,
                               'dis': opt_dis
                           },
                           device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        snapshot_interval = (args.snapshot_interval, 'iteration')
        display_interval = (args.display_interval, 'iteration')
        # Save only model parameters.
        # `snapshot` extension will save all the trainer module's attribute,
        # including `train_iter`.
        # However, `train_iter` depends on scattered dataset, which means that
        # `train_iter` may be different in each process.
        # Here, instead of saving whole trainer module, only the network models
        # are saved.
        trainer.extend(extensions.snapshot_object(
            gen, 'gen_iter_{.updater.iteration}.npz'),
                       trigger=snapshot_interval)
        trainer.extend(extensions.snapshot_object(
            dis, 'dis_iter_{.updater.iteration}.npz'),
                       trigger=snapshot_interval)
        trainer.extend(extensions.LogReport(trigger=display_interval))
        trainer.extend(extensions.PrintReport([
            'epoch',
            'iteration',
            'gen/loss',
            'dis/loss',
            'elapsed_time',
        ]),
                       trigger=display_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))
        trainer.extend(out_generated_image(gen, dis, 10, 10, args.seed,
                                           args.out),
                       trigger=snapshot_interval)

    # Start the training using pre-trained model, saved by snapshot_object.
    if args.gen_model:
        chainer.serializers.load_npz(args.gen_model, gen)
    if args.dis_model:
        chainer.serializers.load_npz(args.dis_model, dis)

    # Run the training.
    trainer.run()
Ejemplo n.º 15
0
def main(args):
    """Train an SSD detector on the DOTA bounding-box dataset.

    Args:
        args: Parsed command-line namespace. Reads ``model`` ('ssd300' or
            'ssd512'), ``gpu``, ``batchsize``, ``out`` and ``resume``.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    # Resolve the architecture first so an unknown name fails with a clear
    # message instead of an UnboundLocalError on `model` further down.
    if args.model == 'ssd300':
        model_cls = SSD300
    elif args.model == 'ssd512':
        model_cls = SSD512
    else:
        raise ValueError(
            "unknown model {!r}: expected 'ssd300' or 'ssd512'".format(
                args.model))
    model = model_cls(n_fg_class=len(place_labels),
                      pretrained_model='imagenet')

    model.use_preset('evaluate')
    train_chain = MultiboxTrainChain(model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    # Training samples are transformed with the model's coder, input size
    # and mean before batching.
    train = TransformDataset(DotaBboxDataset(split='train'),
                             Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)

    test = DotaBboxDataset(split='test',
                           use_difficult=True,
                           return_difficult=True)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 args.batchsize,
                                                 repeat=False,
                                                 shuffle=False)

    # initial lr is set to 1e-3 by ExponentialShift
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(train_chain)
    # Parameters named 'b' get their gradient scaled by 2; every other
    # parameter gets weight decay instead.
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (ITERATION, 'iteration'), args.out)
    # Multiply lr by 0.1 at iterations 80000 and 100000.
    trainer.extend(extensions.ExponentialShift('lr', 0.1, init=1e-3),
                   trigger=triggers.ManualScheduleTrigger([80000, 100000],
                                                          'iteration'))

    trainer.extend(DetectionVOCEvaluator(test_iter,
                                         model,
                                         use_07_metric=True,
                                         label_names=place_labels),
                   trigger=(VALIDATE_INTERVAL, 'iteration'))

    log_interval = 10, 'iteration'
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'lr', 'main/loss', 'main/loss/loc',
        'main/loss/conf', 'validation/main/map'
    ]),
                   trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Full trainer snapshots periodically; the bare model only at the end.
    trainer.extend(extensions.snapshot(),
                   trigger=(SNAPSHOT_INTERVAL, 'iteration'))
    trainer.extend(extensions.snapshot_object(
        model, 'model_iter_{.updater.iteration}'),
                   trigger=(ITERATION, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    trainer.run()
Ejemplo n.º 16
0
def main():
    """Train a Kuzushiji character classifier from command-line arguments.

    Builds the backbone selected by ``args.model``, wraps it in
    ``L.Classifier`` and trains it with NesterovAG plus weight decay.
    ``LearningRateDrop(0.1)`` is applied at 50% and 75% of ``args.epoch``.

    Raises:
        ValueError: if ``args.model`` is not one of 'resnet18', 'resnet34'
            or 'mobilenetv3'.
    """
    args = parse_args()
    dump_args(args)

    # setup model
    n_classes = len(KuzushijiUnicodeMapping())
    if args.model == 'resnet18':
        model = Resnet18(n_classes)
    elif args.model == 'resnet34':
        model = Resnet34(n_classes)
    elif args.model == 'mobilenetv3':
        model = MobileNetV3(n_classes)
    else:
        # Previously an unknown name surfaced later as an UnboundLocalError.
        raise ValueError(
            "unknown model {!r}: expected 'resnet18', 'resnet34' or "
            "'mobilenetv3'".format(args.model))
    train_model = L.Classifier(model)

    if args.gpu >= 0:
        chainer.backends.cuda.get_device(args.gpu).use()
        train_model.to_gpu()

    # setup dataset (image size is dictated by the chosen backbone)
    train, val = prepare_dataset(image_size=model.input_size,
                                 full_data=args.full_data)
    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)
    val_iter = chainer.iterators.MultiprocessIterator(val,
                                                      args.batchsize,
                                                      repeat=False,
                                                      shuffle=False)

    # setup optimizer
    optimizer = chainer.optimizers.NesterovAG(lr=args.lr, momentum=0.9)
    optimizer.setup(train_model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    # setup trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(val_iter, train_model,
                                        device=args.gpu))
    trainer.extend(extensions.snapshot(), trigger=(100, 'epoch'))
    trainer.extend(extensions.snapshot_object(model,
                                              'model_{.updater.epoch}.npz'),
                   trigger=(100, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy'
        ]))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # learning rate scheduling: drop at 50% and 75% of the training epochs
    lr_drop_epochs = [int(args.epoch * 0.5), int(args.epoch * 0.75)]
    lr_drop_trigger = triggers.ManualScheduleTrigger(lr_drop_epochs, 'epoch')
    trainer.extend(LearningRateDrop(0.1), trigger=lr_drop_trigger)
    trainer.extend(extensions.observe_lr())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # start training
    trainer.run()
Ejemplo n.º 17
0
def main():
    """Train a VGG classifier on CIFAR-10 or CIFAR-100 with Chainer.

    Options choose the dataset, mini-batch size, SGD learning rate, number
    of epochs, device, output directory, an optional snapshot to resume
    from, and an optional metric to watch for early stopping.

    Raises:
        RuntimeError: if ``--dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    parser = argparse.ArgumentParser(description='Chainer CIFAR example:')
    parser.add_argument(
        '--dataset', default='cifar10',
        help='The dataset to use: cifar10 or cifar100')
    parser.add_argument(
        '--batchsize', '-b', type=int, default=64,
        help='Number of images in each mini-batch')
    parser.add_argument(
        '--learnrate', '-l', type=float, default=0.05,
        help='Learning rate for SGD')
    parser.add_argument(
        '--epoch', '-e', type=int, default=300,
        help='Number of sweeps over the dataset to train')
    parser.add_argument(
        '--device', '-d', type=str, default='-1',
        help='Device specifier. Either ChainerX device '
             'specifier or an integer. If non-negative integer, '
             'CuPy arrays with specified device id are used. If '
             'negative integer, NumPy arrays are used')
    parser.add_argument(
        '--out', '-o', default='result',
        help='Directory to output the result')
    parser.add_argument(
        '--resume', '-r', default='',
        help='Resume the training from snapshot')
    parser.add_argument(
        '--early-stopping', type=str,
        help='Metric to watch for early stopping')
    # `--gpu` is kept for backwards compatibility and writes into `device`.
    legacy = parser.add_argument_group('deprecated arguments')
    legacy.add_argument(
        '--gpu', '-g', dest='device', type=int, nargs='?', const=0,
        help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    device = chainer.get_device(args.device)
    device.use()

    for banner in ('Device: {}'.format(device),
                   '# Minibatch-size: {}'.format(args.batchsize),
                   '# epoch: {}'.format(args.epoch),
                   ''):
        print(banner)

    # Reject unknown datasets before doing any loading work.
    if args.dataset not in ('cifar10', 'cifar100'):
        raise RuntimeError('Invalid dataset choice.')
    if args.dataset == 'cifar10':
        print('Using CIFAR10 dataset.')
        class_labels = 10
        train, test = get_cifar10()
    else:
        print('Using CIFAR100 dataset.')
        class_labels = 100
        train, test = get_cifar100()

    # The Classifier wrapper reports loss and accuracy on every iteration.
    classifier = L.Classifier(models.VGG.VGG(class_labels))
    classifier.to_device(device)

    optimizer = chainer.optimizers.MomentumSGD(args.learnrate)
    optimizer.setup(classifier)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(5e-4))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    # Stop after args.epoch epochs, or earlier when early stopping is enabled.
    epoch_limit = (args.epoch, 'epoch')
    if args.early_stopping:
        stop_trigger = triggers.EarlyStoppingTrigger(
            monitor=args.early_stopping, verbose=True,
            max_trigger=epoch_limit)
    else:
        stop_trigger = epoch_limit

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # Per-epoch evaluation on the held-out split.
    trainer.extend(extensions.Evaluator(test_iter, classifier, device=device))

    # Halve the learning rate every 25 epochs.
    trainer.extend(extensions.ExponentialShift('lr', 0.5),
                   trigger=(25, 'epoch'))

    # Graph dumping is skipped on ChainerX devices (not supported yet).
    # TODO(imanishi): Support for ChainerX
    if not isinstance(device, backend.ChainerxDevice):
        trainer.extend(extensions.DumpGraph('main/loss'))

    # Per-epoch trainer snapshot. The filename spelling is kept as-is so that
    # existing snapshot files keep matching.
    trainer.extend(extensions.snapshot(
        filename='snaphot_epoch_{.updater.epoch}'))

    # Log, print and show progress. "main" is the optimizer's target link;
    # "validation" is the Evaluator's default name.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Ejemplo n.º 18
0
            '_epoch_{.updater.epoch}.trainer'
    else:
        save_fn = 'encdec4_finetune_epoch_' + '{.updater.epoch}.trainer'
    trainer.extend(extensions.snapshot(
        filename=save_fn, trigger=(args.snapshot_epoch, 'epoch')),
        priority=0, invoke_before_training=False)

    # Add Logger
    if not args.finetune:
        log_fn = 'log_encdec{}.0'.format(args.train_depth)
    else:
        log_fn = 'log_encdec_finetune.0'
    if os.path.exists('{}/{}'.format(result_dir, log_fn)):
        n = int(log_fn.split('.')[-1])
        log_fn = log_fn.replace(str(n), str(n + 1))
    trainer.extend(extensions.ProgressBar())
    if args.show_log_iter:
        log_trigger = args.show_log_iter, 'iteration'
    else:
        log_trigger = 1, 'epoch'
    trainer.extend(extensions.LogReport(trigger=log_trigger, log_name=log_fn))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'main/loss', 'validation/main/loss']))

    # Add remover and recoverer
    if not args.finetune:
        trainer.extend(remove_links, trigger=(args.snapshot_epoch, 'epoch'),
                       priority=500, invoke_before_training=True)
        trainer.extend(recover_links, trigger=(args.snapshot_epoch, 'epoch'),
                       priority=400, invoke_before_training=False)
Ejemplo n.º 19
0
def train(args):
    """Train with the given args

    Builds the E2E model from the input/output dimensions of the first
    validation utterance and optionally attaches a pre-trained RNNLM. It
    writes ``(idim, odim, vars(args))`` to ``<outdir>/model.json``, then runs
    a Chainer trainer driven by the project's CustomUpdater and
    CustomEvaluator.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither 'adadelta' nor
        'adam'
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # Load the saved parameters into `rnnlm`. The former
        # `torch.load(args.rnnlm, rnnlm)` passed the module as `map_location`
        # and discarded the result, so the LM weights were never loaded.
        # `torch_load` is the same loader used with restore_snapshot below.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)), indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Keep a handle on the reporter before a possible DataParallel wrap.
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    # Chainer's trainer expects `optimizer.target` / `optimizer.serialize`.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data (valid_json was loaded above and is not modified in
    # between, so only the training json still needs reading)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make batchsize argument as 1
    # actual batch size is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    if use_sortagrad:
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))
    trainer.extend(extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                                         'epoch', file_name='cer.png'))

    # Save best models
    trainer.extend(extensions.snapshot_object(model, 'model.loss.best', savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # `!=` (equality), not `is not` (identity), for string comparison.
    if mtl_mode != 'ctc':
        trainer.extend(extensions.snapshot_object(model, 'model.acc.best', savefun=torch_save),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: when the watched metric gets worse,
    # reload the best model and decay Adadelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc',
                   'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Ejemplo n.º 20
0
# Adam with its default hyper-parameters updates `model`.
optimizer = O.Adam()
optimizer.setup(model)

# The training iterator repeats with shuffling; the validation iterator does
# one ordered pass every time the Evaluator fires.
train_iter = I.SerialIterator(train, args.batchsize)
val_iter = I.SerialIterator(val, args.batchsize,
                            repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer,
                                   converter=converter, device=args.gpu)
trainer = training.Trainer(updater, (args.epoch, 'epoch'))

# Validation and log aggregation both run every 5 iterations.
trainer.extend(E.Evaluator(val_iter, model,
                           converter=converter, device=args.gpu),
               trigger=(5, 'iteration'))
trainer.extend(E.LogReport(trigger=(5, 'iteration')))

# Loss and accuracy curves, plotted against the epoch, are only added when
# PlotReport says plotting is available.
if E.PlotReport.available():
    trainer.extend(E.PlotReport(
        ['main/loss', 'validation/main/loss'],
        'epoch',
        file_name='loss.png'))
    trainer.extend(E.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        'epoch',
        file_name='accuracy.png'))

trainer.extend(
    E.PrintReport([
        'epoch', 'iteration',
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy',
        'elapsed_time',
    ]),
    trigger=(5, 'iteration'))
trainer.extend(E.ProgressBar(update_interval=1))

trainer.run()
Ejemplo n.º 21
0
def main():
    """Train a Faster R-CNN face detector on WIDER FACE with hard-coded settings.

    The command-line interface below is disabled; its former defaults are
    inlined: seed 0, lr 1e-3, 70000 iterations, lr x0.1 at 50000, and output
    to 'result'. Training images listed in BLACKLIST_FILE are excluded.
    """
    # Disabled CLI, kept as a record of the original options.
    #parser = argparse.ArgumentParser(
    #   description='ChainerCV training example: Faster R-CNN')
    #parser.add_argument('--gpu', '-g', type=int, default=-1)
    #parser.add_argument('--lr', '-l', type=float, default=1e-3)
    #parser.add_argument('--out', '-o', default='result',
    #                    help='Output directory')
    #parser.add_argument('--seed', '-s', type=int, default=0)
    #parser.add_argument('--step_size', '-ss', type=int, default=50000)
    #parser.add_argument('--iteration', '-i', type=int, default=70000)
    #parser.add_argument('--train_data_dir', '-t', default=WIDER_TRAIN_DIR,
    #                    help='Training dataset (WIDER_train)')
    #parser.add_argument('--train_annotation', '-ta', default=WIDER_TRAIN_ANNOTATION_MAT,
    #                    help='Annotation file (.mat) for training dataset')
    #parser.add_argument('--val_data_dir', '-v', default=WIDER_VAL_DIR,
    #                    help='Validation dataset (WIDER_train)')
    #parser.add_argument('--val_annotation', '-va', default=WIDER_VAL_ANNOTATION_MAT,
    #                    help='Annotation file (.mat) for validation dataset')
    #args = parser.parse_args()

    #np.random.seed(args.seed)
    np.random.seed(0)

    # Logger used to record processed files; DEBUG output goes to
    # 'filelog.log' in the working directory.
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(filename='filelog.log')
    handler.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    print('logger')
    # One excluded file name per non-empty line of BLACKLIST_FILE, stripped.
    blacklist = []
    with open(BLACKLIST_FILE, 'r') as f:
        for line in f:
            l = line.strip()
            if l:
                blacklist.append(line.strip())

    # train_data = VOCDetectionDataset(split='trainval', year='2007')
    # test_data = VOCDetectionDataset(split='test', year='2007',
    # use_difficult=True, return_difficult=True)
    #train_data = WIDERFACEDataset(args.train_data_dir, args.train_annotation,
    #    logger=logger, exclude_file_list=blacklist)
    #test_data = WIDERFACEDataset(args.val_data_dir, args.val_annotation)

    train_data = WIDERFACEDataset(WIDER_TRAIN_DIR,
                                  WIDER_TRAIN_ANNOTATION_MAT,
                                  logger=logger,
                                  exclude_file_list=blacklist)
    test_data = WIDERFACEDataset(WIDER_VAL_DIR, WIDER_VAL_ANNOTATION_MAT)

    # NOTE(review): `faster_rcnn` is not assigned anywhere in this function
    # (its construction below is commented out). It must exist as a
    # module-level global, otherwise the next line raises NameError. TODO
    # confirm where it is defined.
    # faster_rcnn = FasterRCNNVGG16(n_fg_class=len(voc_detection_label_names),
    # pretrained_model='imagenet')
    faster_rcnn.use_preset('evaluate')
    model = FasterRCNNTrainChain(faster_rcnn)
    # NOTE(review): `0 >= 0` is always true, so this branch always runs.
    # It targets device -1 here while the updater below uses device=0.
    # Confirm which device is actually intended.
    if 0 >= 0:
        model.to_gpu(-1)
        #model.to_cpu()
        chainer.cuda.get_device(-1).use()
        #chainer.cuda.get_device_from_array()
        #chainer.cuda.to_cpu()
    #optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
    optimizer = chainer.optimizers.MomentumSGD(lr=1e-3, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))

    # Only the training split is passed through `transform`.
    train_data = TransformDataset(train_data, transform)
    #import pdb; pdb.set_trace()
    #train_iter = chainer.iterators.MultiprocessIterator(
    #    train_data, batch_size=1, n_processes=None, shared_mem=100000000)
    train_iter = chainer.iterators.SerialIterator(train_data, batch_size=1)
    test_iter = chainer.iterators.SerialIterator(test_data,
                                                 batch_size=1,
                                                 repeat=False,
                                                 shuffle=False)
    updater = chainer.training.updater.StandardUpdater(train_iter,
                                                       optimizer,
                                                       device=0)

    trainer = training.Trainer(updater, (70000, 'iteration'), out='result')

    # Save the bare detector once, at the final iteration (70000).
    trainer.extend(extensions.snapshot_object(model.faster_rcnn,
                                              'snapshot_model.npz'),
                   trigger=(70000, 'iteration'))
    # Multiply lr by 0.1 on a (50000, 'iteration') interval trigger.
    trainer.extend(extensions.ExponentialShift('lr', 0.1),
                   trigger=(50000, 'iteration'))

    log_interval = 20, 'iteration'
    plot_interval = 3000, 'iteration'
    print_interval = 20, 'iteration'

    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=log_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.PrintReport([
        'iteration',
        'epoch',
        'elapsed_time',
        'lr',
        'main/loss',
        'main/roi_loc_loss',
        'main/roi_cls_loss',
        'main/rpn_loc_loss',
        'main/rpn_cls_loss',
        'validation/main/map',
    ]),
                   trigger=print_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # plot_interval is passed both to PlotReport and to extend().
    if extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport(['main/loss'],
                                             file_name='loss.png',
                                             trigger=plot_interval),
                       trigger=plot_interval)

    # Single-class ('face') VOC-style evaluation at iterations 50000 and 70000.
    trainer.extend(DetectionVOCEvaluator(test_iter,
                                         model.faster_rcnn,
                                         use_07_metric=True,
                                         label_names=('face', )),
                   trigger=ManualScheduleTrigger([50000, 70000], 'iteration'))

    trainer.extend(extensions.dump_graph('main/loss'))

    #try:
    # warnings.filterwarnings('error', category=RuntimeWarning)
    trainer.run()
Ejemplo n.º 22
0
def main():
    """Train an ImageNet-style convnet from image-label list files.

    Parses the command line and builds the selected architecture. The
    number of output classes is taken from the distinct labels in the
    training list and is only passed to 'googlenet'. The model is optionally
    initialised from ``--initmodel``, falling back to copying weights from a
    default GoogLeNet if the file does not match. Training uses Adam or
    MomentumSGD with weight decay and is evaluated and logged once per epoch.
    """
    archs = {
        'alex': alex.Alex,
        'alex_fp16': alex.AlexFp16,
        'googlenet': googlenet2.GoogLeNet,
        'googlenetbn': googlenetbn.GoogLeNetBN,
        'googlenetbn_fp16': googlenetbn.GoogLeNetBNFp16,
        'nin': nin.NIN,
        'resnet50': resnet50.ResNet50
    }

    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to training image-label list file')
    parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--arch',
                        '-a',
                        choices=archs.keys(),
                        default='nin',
                        help='Convnet architecture')
    parser.add_argument('--batchsize',
                        '-B',
                        type=int,
                        default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch',
                        '-E',
                        type=int,
                        default=10,
                        help='Number of epochs to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob',
                        '-j',
                        type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean',
                        '-m',
                        default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Output directory')
    parser.add_argument('--root',
                        '-R',
                        default='.',
                        help='Root directory path of image files')
    parser.add_argument('--optimizer', default='adam', help='optimizer')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0001,
                        help='weight decay')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-3,
                        help='learning rate. if adam, it is mean alpha')
    parser.add_argument('--lr_shift',
                        type=float,
                        default=0.5,
                        help='lr exponential shift. 0 mean not to shift')
    parser.add_argument('--val_batchsize',
                        '-b',
                        type=int,
                        default=250,
                        help='Validation minibatch size')
    # NOTE: --test is still accepted for CLI compatibility, but no interval
    # depends on it any more (the values it selected were never used).
    parser.add_argument('--test', action='store_true')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    model_cls = archs[args.arch]

    # Load the datasets and mean file
    insize = model_cls.insize
    mean = np.load(args.mean)
    train = PreprocessedDataset(args.train, args.root, mean, insize)
    val = PreprocessedDataset(args.val, args.root, mean, insize, False)
    # Number of classes = distinct values in column 1 of the training list.
    outsize = len(set(pd.read_csv(args.train, sep=' ', header=None)[1]))

    # Initialize the model to train
    if args.arch == 'googlenet':
        model = model_cls(output_size=outsize)
    else:
        model = model_cls()
    if args.initmodel:
        print('Load model from', args.initmodel)
        try:
            chainer.serializers.load_npz(args.initmodel, model)
        except (ValueError, KeyError) as e:
            # Parameter layout differs: load into a stock GoogLeNet and copy
            # the matching weights over.
            print('not match model. try default GoogLeNet. "{}"'.format(e))
            src_model = googlenet.GoogLeNet()
            chainer.serializers.load_npz(args.initmodel, src_model)
            copy_model(src_model, model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()  # Make the GPU current
        model.to_gpu()

    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Set up an optimizer: Adam (lr is Adam's alpha) or MomentumSGD.
    print('set optimizer: {}, learning rate: {}'.format(
        args.optimizer, args.learning_rate))
    if args.optimizer == 'adam':
        optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    else:
        optimizer = chainer.optimizers.MomentumSGD(lr=args.learning_rate,
                                                   momentum=0.9)

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    test_interval = 1, 'epoch'

    trainer.extend(extensions.Evaluator(val_iter, model, device=args.gpu),
                   trigger=test_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    # Save the model (not the full trainer) at the default per-epoch trigger.
    trainer.extend(
        extensions.snapshot_object(model,
                                   filename='model_epoch-{.updater.epoch}'))
    # Be careful to pass the interval directly to LogReport
    # (it determines when to emit log rather than when to read observations)
    trainer.extend(extensions.LogReport(trigger=test_interval))
    trainer.extend(extensions.observe_lr(), trigger=test_interval)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'lr'
    ]),
                   trigger=test_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.lr_shift > 0:
        # Multiply the learning rate (Adam's alpha, or SGD's lr) by
        # args.lr_shift every 25 epochs.
        if args.optimizer == 'adam':
            trainer.extend(extensions.ExponentialShift('alpha', args.lr_shift),
                           trigger=(25, 'epoch'))
        else:
            trainer.extend(extensions.ExponentialShift('lr', args.lr_shift),
                           trigger=(25, 'epoch'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Ejemplo n.º 23
0
def main():
    """Train a DCGAN generator/discriminator pair on dish images.

    Command-line arguments are parsed, written to
    ``result_<number>/result_<number>_<seed>/args.txt`` and echoed to stdout.
    The generator and the selected discriminator are then trained with a
    Chainer ``Trainer`` on every ``*.jpg`` file in ``center_crop_data_128``.

    Only discriminator types 1 (minibatch discrimination) and 3 (GAP) are
    implemented here. Types 0 and 2 are still accepted by ``choices`` for CLI
    compatibility, but are rejected through ``parser.error`` (which exits)
    before any output directory is created.
    """
    import numpy as np
    import argparse

    # Build the argument parser.
    parser = argparse.ArgumentParser(
        prog='train',  # program name
        usage='train DCGAN to dish',  # usage line
        description='description',  # shown before the argument help
        epilog='end',  # shown after the argument help
        add_help=True,  # add the -h/--help option
    )

    # Register the arguments.
    parser.add_argument('-s', '--seed', help='seed', type=int, required=True)
    parser.add_argument('-n',
                        '--number',
                        help='the number of experiments.',
                        type=int,
                        required=True)
    parser.add_argument('--hidden',
                        help='the number of codes of Generator.',
                        type=int,
                        default=100)
    parser.add_argument('-e',
                        '--epoch',
                        help='the number of epoch, defalut value is 300',
                        type=int,
                        default=300)
    parser.add_argument('-bs',
                        '--batch_size',
                        help='batch size. defalut value is 128',
                        type=int,
                        default=128)
    parser.add_argument('-g',
                        '--gpu',
                        help='specify gpu by this number. defalut value is 0',
                        choices=[0, 1],
                        type=int,
                        default=0)
    parser.add_argument(
        '-ks',
        '--ksize',
        help='specify ksize of generator by this number. any of following;'
        ' 4 or 6. defalut value is 6',
        choices=[4, 6],
        type=int,
        default=6)
    parser.add_argument(
        '-dis',
        '--discriminator',
        help='specify discriminator by this number. any of following;'
        ' 0: original, 1: minibatch discriminatio, 2: feature matching, 3: GAP. defalut value is 3',
        choices=[0, 1, 2, 3],
        type=int,
        default=3)
    parser.add_argument(
        '-ts',
        '--tensor_shape',
        help=
        'specify Tensor shape by this numbers. first args denotes to B, seconds to C.'
        ' defalut value are B:32, C:8',
        type=int,
        default=[32, 8],
        nargs=2)
    parser.add_argument('-V',
                        '--version',
                        version='%(prog)s 1.0.0',
                        action='version',
                        default=False)

    # Parse the arguments.
    args = parser.parse_args()

    # Discriminators 0 (original) and 2 (feature matching) have no active
    # implementation below. Without this check `dis` would be unbound and the
    # run would die with NameError only after the result directories exist.
    if args.discriminator not in (1, 3):
        parser.error('discriminator {0} is not implemented in this script; '
                     'use 1 or 3'.format(args.discriminator))

    gpu = args.gpu
    batch_size = args.batch_size
    n_hidden = args.hidden
    epoch = args.epoch
    seed = args.seed
    number = args.number  # number of experiments
    # `choices` restricts ksize to 4 or 6; each has a fixed padding.
    pad = 2 if args.ksize == 6 else 1

    # Per-experiment, per-seed output directory.
    out = pathlib.Path("result_{0}".format(number)) / "result_{0}_{1}".format(
        number, seed)
    out.mkdir(parents=True, exist_ok=True)

    # Record the arguments (hyper-parameter settings).
    with open(out / "args.txt", "w") as f:
        f.write(str(args))

    print('GPU: {}'.format(gpu))
    print('# Minibatch-size: {}'.format(batch_size))
    print('# n_hidden: {}'.format(n_hidden))
    print('# epoch: {}'.format(epoch))
    print('# out: {}'.format(out))
    print('# seed: {}'.format(seed))
    print('# ksize: {}'.format(args.ksize))
    print('# pad: {}'.format(pad))

    # Fix the random seeds (NumPy, and CuPy when CUDA is available).
    np.random.seed(seed)
    if chainer.backends.cuda.available:
        chainer.backends.cuda.cupy.random.seed(seed)

    # Import and set up the selected discriminator.
    # Disabled variants kept for reference:
    #   0: from discriminator import Discriminator (original)
    #   2: from discriminator_fm import Discriminator
    #      from updater_fm import DCGANUpdater (feature matching)
    if args.discriminator == 1:
        print("# Discriminator applied Minibatch Discrimination")
        print('# Tensor shape is A x {0} x {1}'.format(args.tensor_shape[0],
                                                       args.tensor_shape[1]))
        from discriminator_md import Discriminator
        from updater import DCGANUpdater
        dis = Discriminator(B=args.tensor_shape[0], C=args.tensor_shape[1])
    elif args.discriminator == 3:
        print("# Discriminator applied GAP")
        from discriminator import Discriminator
        from updater import DCGANUpdater
        dis = Discriminator()
    print('')

    # Set up the generator to train.
    gen = Generator(n_hidden=n_hidden, ksize=args.ksize, pad=pad)

    if gpu >= 0:
        # Make the specified GPU current and copy both models to it.
        chainer.backends.cuda.get_device_from_id(gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    opt_gen = make_optimizer(gen)
    opt_dis = make_optimizer(dis)

    # Prepare the dataset. Alternative resized sources, currently unused:
    #   "rsize_data_128", "test_rsize_data_128", "unlabeled_rsize_data_128"
    paths = ["center_crop_data_128"]  # center-cropped data
    data_path = []
    for path in paths:
        abs_data_dir = pathlib.Path(path).resolve()
        print("data dir path:", abs_data_dir)
        # glob() order is filesystem-dependent; sort so a fixed seed also
        # yields a reproducible dataset order.
        data_path.extend(sorted(abs_data_dir.glob("*.jpg")))
    print("data length:", len(data_path))
    data = ImageDataset(paths=data_path)  # dtype=np.float32
    train_iter = chainer.iterators.SerialIterator(data, batch_size)

    # Set up the updater and trainer.
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer={
                               'gen': opt_gen,
                               'dis': opt_dis
                           },
                           device=gpu)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

    snapshot_interval = (10, 'epoch')
    display_interval = (1, 'epoch')
    # Snapshots are written with `save_npz` into *.npz files. The original
    # comment said "hdf5"; that looks stale, so confirm against the import.
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch_{.updater.epoch}.npz', savefun=save_npz),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(gen,
                                              'gen_epoch_{.updater.epoch}.npz',
                                              savefun=save_npz),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(dis,
                                              'dis_epoch_{.updater.epoch}.npz',
                                              savefun=save_npz),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'gen/loss', 'dis/loss', 'elapsed_time',
        'dis/accuracy'
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(out_generated_image(gen, dis, 5, 5, seed, out),
                   trigger=display_interval)
    trainer.extend(
        extensions.PlotReport(['gen/loss', 'dis/loss'],
                              x_key='epoch',
                              file_name='loss_{0}_{1}.jpg'.format(
                                  number, seed),
                              grid=False))
    trainer.extend(extensions.dump_graph("gen/loss", out_name="gen.dot"))
    trainer.extend(extensions.dump_graph("dis/loss", out_name="dis.dot"))

    # Run the training.
    trainer.run()
Ejemplo n.º 24
0
def main():
    """Train an ACAI-style autoencoder/critic pair on MNIST.

    MNIST images are resized to 32x32. 50000 images of the training split
    (chosen with ``split_dataset_random(..., seed=0)``) are used for
    training; the held-out remainder is currently unused. Every epoch,
    ``Visualizer.out_generated_image`` receives five test images of digit 1
    and five of digit 5. Model snapshots are written to ``result/<network>``
    every ``--snapshot`` epochs.

    Raises:
        ValueError: if ``--network`` is neither ``'conv'`` nor ``'fl'``.
    """
    parser = argparse.ArgumentParser(description="Vanilla_AE")
    parser.add_argument("--batchsize", "-b", type=int, default=64)
    parser.add_argument("--epoch", "-e", type=int, default=100)
    parser.add_argument("--gpu", "-g", type=int, default=0)
    parser.add_argument("--snapshot", "-s", type=int, default=10)
    parser.add_argument("--n_dimz", "-z", type=int, default=16)
    parser.add_argument("--dataset", "-d", type=str, default='mnist')
    parser.add_argument("--network", "-n", type=str, default='conv')

    args = parser.parse_args()

    def transform(in_data):
        """Resize an unlabelled image to 32x32."""
        return resize(in_data, (32, 32))

    def transform2(in_data):
        """Resize the image of an (image, label) pair to 32x32, keeping the label."""
        img, label = in_data
        return resize(img, (32, 32)), label

    # Standard-library and project-local modules.
    import itertools
    import Updater
    import Visualizer

    # Print settings.
    print("GPU:{}".format(args.gpu))
    print("epoch:{}".format(args.epoch))
    print("Minibatch_size:{}".format(args.batchsize))
    print('')
    out = os.path.join('result', args.network)
    batchsize = args.batchsize
    max_epoch = args.epoch

    # Resolve the network module before loading any data, so a typo in
    # --network fails immediately with a clear message.
    if args.network == 'conv':
        import Network.mnist_conv as Network
    elif args.network == 'fl':
        import Network.mnist_fl as Network
    else:
        raise ValueError(
            "unknown --network '{}' (expected 'conv' or 'fl')".format(
                args.network))

    if args.dataset != 'mnist':
        # --dataset is parsed, but only MNIST is wired up below.
        print("warning: --dataset '{}' is ignored; training on MNIST".format(
            args.dataset))

    train_val, _ = mnist.get_mnist(withlabel=False, ndim=3)
    train_val = TransformDataset(train_val, transform)
    # Labelled test split, used only for visualization.
    _, test = mnist.get_mnist(withlabel=True, ndim=3)
    test = TransformDataset(test, transform2)
    label1 = 1
    label2 = 5
    # Take the first five test images labelled `label1` and the 6th-10th
    # labelled `label2`. islice stops as soon as enough matches are found,
    # instead of resizing the whole test split twice.
    test1 = list(
        itertools.islice((img for img, label in test if label == label1), 0,
                         5))
    test2 = list(
        itertools.islice((img for img, label in test if label == label2), 5,
                         10))

    AE = Network.AE(n_dimz=args.n_dimz, batchsize=args.batchsize)
    Critic = Network.Critic()
    # Deterministic 50000-image training subset; the remainder is unused.
    train, _ = split_dataset_random(train_val, 50000, seed=0)

    # Set up the iterator.
    train_iter = iterators.SerialIterator(train, batchsize)

    # Set up the optimizers.
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer for *model* with weight decay 1e-4."""
        optimizer = optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001))
        return optimizer

    opt_AE = make_optimizer(AE)
    opt_Critic = make_optimizer(Critic)

    # Set up the trainer.
    updater = Updater.ACAIUpdater(model=(AE, Critic),
                                  iterator=train_iter,
                                  optimizer={
                                      'AE': opt_AE,
                                      'Critic': opt_Critic
                                  },
                                  device=args.gpu)

    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=out)
    trainer.extend(extensions.LogReport(log_name='log'))
    snapshot_interval = (args.snapshot, 'epoch')
    display_interval = (1, 'epoch')
    trainer.extend(extensions.snapshot_object(
        AE, filename='AE_snapshot_epoch_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        Critic, filename='Critic_snapshot_epoch_{.updater.epoch}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'Critic_loss', 'AE_loss', 'rec_loss']),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar())
    trainer.extend(Visualizer.out_generated_image(AE, Critic, test1, test2,
                                                  out),
                   trigger=(1, 'epoch'))
    trainer.run()
Ejemplo n.º 25
0
def main():
    """Train a DCGAN on CIFAR-10 or on a directory of PNG/JPG images.

    When ``--dataset`` is empty, the CIFAR-10 training split is used (scaled to
    [0, 255]). Otherwise every file in that directory whose name ends in
    ``.png`` or ``.jpg`` (case-insensitive) is loaded with
    ``chainer.datasets.ImageDataset``. Snapshots, logs and generated images
    are written to ``--out``. Training can resume from a trainer snapshot via
    ``--resume``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: DCGAN')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=50,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=1000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--dataset',
                        '-i',
                        default='',
                        help='Directory of image files.  Default is cifar-10.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--n_hidden',
                        '-n',
                        type=int,
                        default=100,
                        help='Number of hidden units (z)')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed of z at visualization stage')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=1000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# n_hidden: {}'.format(args.n_hidden))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up the neural networks to train.
    gen = Generator(n_hidden=args.n_hidden)
    dis = Discriminator()

    if args.gpu >= 0:
        # Make the specified GPU current and copy both models to it.
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer for *model* with weight decay 1e-4."""
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001), 'hook_dec')
        return optimizer

    opt_gen = make_optimizer(gen)
    opt_dis = make_optimizer(dis)

    if args.dataset == '':
        # Load the CIFAR10 dataset if args.dataset is not specified
        train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=255.)
    else:
        # Match on the file extension, not a substring: the old
        # `'png' in f or 'jpg' in f` test also picked up names such as
        # 'jpg_list.txt'. Sorting makes the dataset order independent of
        # os.listdir's arbitrary order.
        image_files = sorted(
            f for f in os.listdir(args.dataset)
            if f.lower().endswith(('.png', '.jpg')))
        print('{} contains {} image files'.format(args.dataset,
                                                  len(image_files)))
        train = chainer.datasets.ImageDataset(paths=image_files,
                                              root=args.dataset)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Set up the trainer.
    updater = DCGANUpdater(models=(gen, dis),
                           iterator=train_iter,
                           optimizer={
                               'gen': opt_gen,
                               'dis': opt_dis
                           },
                           device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    display_interval = (args.display_interval, 'iteration')
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        gen, 'gen_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        'epoch',
        'iteration',
        'gen/loss',
        'dis/loss',
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(out_generated_image(gen, dis, 10, 10, args.seed, args.out),
                   trigger=snapshot_interval)

    if args.resume:
        # Resume from a snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training.
    trainer.run()
Ejemplo n.º 26
0
def main():
    """Train an LSTM language model on Penn Treebank, with or without variational dropout.

    With a non-zero ``--pretrain``, a plain ``nets.RNNForLM`` is trained with
    SGD (lr=1.0, weight decay 5e-4, gradient clipping 5). Otherwise
    ``nets.RNNForLMVD`` is trained with Adam (alpha=1e-5, gradient clipping
    5). In that case its weights are loaded from the ``--resume`` npz file
    when one is given. After training, the test-set perplexity is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=20,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--bproplen',
                        '-l',
                        type=int,
                        default=35,
                        help='Number of words in each mini-batch '
                        '(= length of truncated BPTT)')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=39,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    # type=int is required: without it `--pretrain 0` yields the string '0',
    # which is truthy and would silently enable pretraining.
    parser.add_argument('--pretrain',
                        type=int,
                        default=0,
                        help='Pretrain (w/o VD) or not (w/ VD).' +
                        ' default is not (0).')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--test',
                        action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.set_defaults(test=False)
    parser.add_argument('--unit',
                        '-u',
                        type=int,
                        default=650,
                        help='Number of LSTM units in each layer')
    args = parser.parse_args()

    # Load the Penn Tree Bank long word sequence dataset.
    train, val, test = chainer.datasets.get_ptb_words()
    n_vocab = max(train) + 1  # train is just an array of integers
    print('#vocab =', n_vocab)

    if args.test:
        train = train[:1000]
        val = val[:1000]
        test = test[:1000]

    train_iter = ParallelSequentialIterator(train, args.batchsize)
    val_iter = ParallelSequentialIterator(val, 1, repeat=False)
    test_iter = ParallelSequentialIterator(test, 1, repeat=False)
    print('# of train:', len(train))
    n_iters = len(train) // args.batchsize // args.bproplen
    print('# of train batch/epoch:', n_iters)

    # Prepare an RNNLM model.
    if args.pretrain:
        model = nets.RNNForLM(n_vocab, args.unit)

        def calc_loss(x, t):
            """Run *x* through the model, report loss/accuracy against *t*, return the loss."""
            model.y = model(x)
            model.loss = F.softmax_cross_entropy(model.y, t)
            reporter.report({'loss': model.loss}, model)
            # Presumably mirrors the 'main/class' column printed for the VD model.
            reporter.report({'class': model.loss}, model)
            model.accuracy = F.accuracy(model.y, t)
            reporter.report({'accuracy': model.accuracy}, model)
            return model.loss

        model.calc_loss = calc_loss
        model.use_raw_dropout = True
    else:
        # Resumed VD runs use warm_up=1e-5; fresh ones use 1e-7.
        warm_up = 1e-5 if args.resume else 1e-7
        model = nets.RNNForLMVD(n_vocab, args.unit, warm_up=warm_up)
        # model.to_variational_dropout()
        if args.resume:
            chainer.serializers.load_npz(args.resume, model)
        configuration.config.user_memory_efficiency = (
            0 if args.bproplen <= 20 else 3)

    if args.gpu >= 0:
        # Make the GPU current and copy the model to it.
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    # Set up an optimizer.
    if args.pretrain:
        optimizer = chainer.optimizers.SGD(lr=1.0)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))
        optimizer.add_hook(chainer.optimizer.GradientClipping(5.))
    else:
        optimizer = chainer.optimizers.Adam(alpha=1e-5)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(5.))

    # Set up a trainer.
    updater = BPTTUpdater(train_iter,
                          optimizer,
                          args.bproplen,
                          args.gpu,
                          loss_func=model.calc_loss,
                          decay_iter=((n_iters * 6,
                                       n_iters) if args.pretrain else (0, 0)))
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Model with shared params and distinct states.
    eval_model = L.Classifier(model.copy())
    eval_rnn = eval_model.predictor
    trainer.extend(
        extensions.Evaluator(
            val_iter,
            eval_model,
            device=args.gpu,
            # Reset the RNN state at the beginning of each evaluation.
            eval_hook=lambda _: eval_rnn.reset_state()))

    interval = min(10 if args.test else 100, max(n_iters, 1))
    trainer.extend(
        extensions.LogReport(postprocess=compute_perplexity,
                             trigger=(interval, 'iteration')))

    if args.pretrain:
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'perplexity', 'val_perplexity',
            'main/accuracy', 'validation/main/accuracy', 'main/lr',
            'elapsed_time'
        ]),
                       trigger=(interval, 'iteration'))
    else:
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'perplexity', 'val_perplexity',
            'main/accuracy', 'validation/main/accuracy', 'main/class',
            'main/kl', 'main/mean_p', 'main/sparsity', 'main/W/Wnz',
            'main/kl_coef', 'main/lr', 'elapsed_time'
        ]),
                       trigger=(interval, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=1 if args.test else 10))
    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'))

    trainer.run()

    # Evaluate the final model.
    print('test')
    eval_rnn.reset_state()
    evaluator = extensions.Evaluator(test_iter, eval_model, device=args.gpu)
    result = evaluator()
    print('test perplexity:', np.exp(float(result['main/loss'])))
Ejemplo n.º 27
0
def main():
    parser = argparse.ArgumentParser()
    # general configuration
    parser.add_argument('--gpu',
                        '-g',
                        default='-1',
                        type=str,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--outdir',
                        type=str,
                        required=True,
                        help='Output directory')
    parser.add_argument('--debugmode', default=1, type=int, help='Debugmode')
    parser.add_argument('--dict', required=True, help='Dictionary')
    parser.add_argument('--seed', default=1, type=int, help='Random seed')
    parser.add_argument('--debugdir',
                        type=str,
                        help='Output directory for debugging')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--minibatches',
                        '-N',
                        type=int,
                        default='-1',
                        help='Process only N minibatches (for debug)')
    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    # task related
    parser.add_argument('--train-feat',
                        type=str,
                        required=True,
                        help='Filename of train feature data (Kaldi scp)')
    parser.add_argument('--valid-feat',
                        type=str,
                        required=True,
                        help='Filename of validation feature data (Kaldi scp)')
    parser.add_argument('--train-label',
                        type=str,
                        required=True,
                        help='Filename of train label data (json)')
    parser.add_argument('--valid-label',
                        type=str,
                        required=True,
                        help='Filename of validation label data (json)')
    # network archtecture
    # encoder
    parser.add_argument('--etype',
                        default='blstmp',
                        type=str,
                        choices=['blstmp', 'vggblstmp', 'vggblstm'],
                        help='Type of encoder network architecture')
    parser.add_argument('--elayers',
                        default=4,
                        type=int,
                        help='Number of encoder layers')
    parser.add_argument('--eunits',
                        '-u',
                        default=300,
                        type=int,
                        help='Number of encoder hidden units')
    parser.add_argument('--eprojs',
                        default=320,
                        type=int,
                        help='Number of encoder projection units')
    parser.add_argument(
        '--subsample',
        default=1,
        type=str,
        help=
        'Subsample input frames x_y_z means subsample every x frame at 1st layer, '
        'every y frame at 2nd layer etc.')
    # attention
    parser.add_argument('--atype',
                        default='dot',
                        type=str,
                        choices=['dot', 'location'],
                        help='Type of attention architecture')
    parser.add_argument('--adim',
                        default=320,
                        type=int,
                        help='Number of attention transformation dimensions')
    parser.add_argument('--aconv-chans',
                        default=-1,
                        type=int,
                        help='Number of attention convolution channels \
                        (negative value indicates no location-aware attention)'
                        )
    parser.add_argument('--aconv-filts',
                        default=100,
                        type=int,
                        help='Number of attention convolution filters \
                        (negative value indicates no location-aware attention)'
                        )
    # decoder
    parser.add_argument('--dtype',
                        default='lstm',
                        type=str,
                        choices=['lstm'],
                        help='Type of decoder network architecture')
    parser.add_argument('--dlayers',
                        default=1,
                        type=int,
                        help='Number of decoder layers')
    parser.add_argument('--dunits',
                        default=320,
                        type=int,
                        help='Number of decoder hidden units')
    parser.add_argument(
        '--mtlalpha',
        default=0.5,
        type=float,
        help=
        'Multitask learning coefficient, alpha: alpha*ctc_loss + (1-alpha)*att_loss '
    )
    # model (parameter) related
    parser.add_argument('--dropout-rate',
                        default=0.0,
                        type=float,
                        help='Dropout rate')
    # minibatch related
    parser.add_argument('--batch-size',
                        '-b',
                        default=50,
                        type=int,
                        help='Batch size')
    parser.add_argument(
        '--maxlen-in',
        default=800,
        type=int,
        metavar='ML',
        help='Batch size is reduced if the input sequence length > ML')
    parser.add_argument(
        '--maxlen-out',
        default=150,
        type=int,
        metavar='ML',
        help='Batch size is reduced if the output sequence length > ML')
    # optimization related
    parser.add_argument('--opt',
                        default='adadelta',
                        type=str,
                        choices=['adadelta', 'adam'],
                        help='Optimizer')
    parser.add_argument('--eps',
                        default=1e-8,
                        type=float,
                        help='Epsilon constant for optimizer')
    parser.add_argument('--eps-decay',
                        default=0.01,
                        type=float,
                        help='Decaying ratio of epsilon')
    parser.add_argument('--criterion',
                        default='acc',
                        type=str,
                        choices=['loss', 'acc'],
                        help='Criterion to perform epsilon decay')
    parser.add_argument('--threshold',
                        default=1e-4,
                        type=float,
                        help='Threshold to stop iteration')
    parser.add_argument('--epochs',
                        '-e',
                        default=30,
                        type=int,
                        help='Number of maximum epochs')
    parser.add_argument('--grad-clip',
                        default=5,
                        type=float,
                        help='Gradient norm threshold to clip')
    args = parser.parse_args()

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
        logging.warning('Skip DEBUG/INFO messages')

    # display PYTHONPATH
    logging.info('python path = ' + os.environ['PYTHONPATH'])

    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    # seed setting (chainer seed may not need it)
    nseed = args.seed
    random.seed(nseed)
    np.random.seed(nseed)
    os.environ['CHAINER_SEED'] = str(nseed)
    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducability
    # revmoe type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('chainer type check is disabled')
    # use determinisitic computation or not
    if args.debugmode < 1:
        chainer.config.cudnn_deterministic = False
        logging.info('chainer cudnn deterministic is disabled')
    else:
        chainer.config.cudnn_deterministic = True
    # load dictionary for debug log
    if args.debugmode > 0 and args.dict is not None:
        with open(args.dict, 'r') as f:
            dictionary = f.readlines()
        char_list = [d.split(' ')[0] for d in dictionary]
        for i, char in enumerate(char_list):
            if char == '<space>':
                char_list[i] = ' '
        char_list.insert(0, '<sos>')
        char_list.append('<eos>')
        args.char_list = char_list
    else:
        args.char_list = None

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    with open(args.valid_label, 'r') as f:
        valid_json = json.load(f)['utts']
    utts = valid_json.keys()
    idim = int(valid_json[utts[0]]['idim'])
    odim = int(valid_json[utts[0]]['odim'])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = MTLLoss(e2e, args.mtlalpha)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.conf'
    with open(model_conf, 'w') as f:
        logging.info('writing a model config file to' + model_conf)
        # TODO use others than pickle, possibly json, and save as a text
        pickle.dump((idim, odim, args), f)
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    if args.gpu == 'jhu':
        # TODO make this one controlled at conf/gpu.conf or whatever
        # this is JHU CLSP cluster setup
        cmd = '/home/gkumar/scripts/free-gpu'
        p = subprocess.Popen(cmd,
                             shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout_data, stderr_data = p.communicate()
        gpu_id = int(stdout_data.rstrip())
    else:
        gpu_id = int(args.gpu)
    logging.info('gpu id: ' + str(gpu_id))
    if gpu_id >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # read json data
    with open(args.train_label, 'r') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_label, 'r') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    # hack to make batchsze argument as 1
    # actual bathsize is included in a list
    train_iter = chainer.iterators.SerialIterator(train, 1)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  1,
                                                  repeat=False,
                                                  shuffle=False)

    # prepare Kaldi reader
    train_reader = kaldi_io.RandomAccessBaseFloatMatrixReader(args.train_feat)
    valid_reader = kaldi_io.RandomAccessBaseFloatMatrixReader(args.valid_feat)

    # Set up a trainer
    updater = SeqUpdaterKaldi(train_iter, optimizer, train_reader, gpu_id)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        SeqEvaluaterKaldi(valid_iter, model, valid_reader, device=gpu_id))

    # Take a snapshot for each specified epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(
        extensions.snapshot_object(model, 'model.acc.best'),
        trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(100, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(100, 'iteration'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
Ejemplo n.º 28
0
def train(args):
    """Train a speech-translation (ST) model with the given args.

    Builds the model (optionally initialized from pre-trained modules via
    ``args.enc_init`` / ``args.dec_init``), an optional RNNLM, the optimizer
    (adadelta / adam / noam, optionally wrapped by apex.amp), the batched
    train/valid data loaders and a Chainer ``training.Trainer`` with
    evaluation, plotting, snapshot, decay and logging extensions, then runs
    training and finally calls ``check_early_stop``.

    Side effects: creates ``args.outdir`` if needed and writes ``model.json``
    into it; multiplies ``args.batch_size`` by ``args.ngpu`` when more than
    one GPU is used (this happens after ``model.json`` has been written).

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not adadelta, adam or noam.
        ImportError: If an apex train dtype (O0-O3) is requested but apex
            is not installed.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # Initialize with pre-trained ASR encoder and MT decoder
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, interface=STInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, STInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list),
                rnnlm_args.layer,
                rnnlm_args.unit,
                getattr(rnnlm_args, "embed_unit",
                        None),  # for backward compatibility
            ))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu: the batch size is scaled by ngpu
    # (note: model.json above still records the unscaled value)
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device; apex opt levels (O0-O3) are not torch dtypes, so the
    # model stays float32 in that case and apex handles mixed precision below
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            # the noam wrapper holds the real torch optimizer in .optimizer
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # Give the torch optimizer the Chainer-optimizer attributes the trainer
    # machinery reads: `target` is the reporter and serialization is
    # delegated to it.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype,
                                asr_task=args.asr_weight > 0)

    # read json data (valid_json was already parsed above for the dims and
    # is not modified in between, so it is reused instead of re-read)
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    # sortagrad: -1 means "for all epochs", a positive value is an epoch count
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(train,
                                     lambda data: converter([load_tr(data)])),
            batch_size=1,
            num_workers=args.n_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(valid,
                                     lambda data: converter([load_cv(data)])),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.n_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        train_iter,
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    if use_sortagrad:
        # ShufflingEnabler calls start_shuffle() on every element it gets,
        # so hand it the data loaders themselves; passing the {"main": ...}
        # dict would raise AttributeError when the trigger fires.
        trainer.extend(
            ShufflingEnabler(list(train_iter.values())),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the validation dataset (each epoch by default,
    # or every save_interval_iters iterations)
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, valid_iter, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, valid_iter, reporter, device, args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0:
        # first num_save_attention validation utterances, longest input first
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        # unwrap a wrapped model (one exposing .module) if present
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_asr",
                "validation/main/loss_asr",
                "main/loss_st",
                "validation/main/loss_st",
            ],
            "epoch",
            file_name="loss.png",
        ))
    trainer.extend(
        extensions.PlotReport(
            [
                "main/acc",
                "validation/main/acc",
                "main/acc_asr",
                "validation/main/acc_asr",
            ],
            "epoch",
            file_name="acc.png",
        ))
    trainer.extend(
        extensions.PlotReport(["main/bleu", "validation/main/bleu"],
                              "epoch",
                              file_name="bleu.png"))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # When the validation criterion gets worse than its best value, restore
    # the best model ("model.<criterion>.best") and decay adadelta's eps /
    # adam's lr. The restore extension is registered before the decay one.
    if args.opt in ("adadelta", "adam") and args.criterion in ("acc", "loss"):
        criterion_key = "validation/main/" + args.criterion
        if args.criterion == "acc":

            def got_worse(best_value, current_value):
                # accuracy: higher is better
                return best_value > current_value
        else:

            def got_worse(best_value, current_value):
                # loss: lower is better
                return best_value < current_value

        trainer.extend(
            restore_snapshot(model,
                             args.outdir + "/model." + args.criterion +
                             ".best",
                             load_fn=torch_load),
            trigger=CompareValueTrigger(criterion_key, got_worse),
        )
        if args.opt == "adadelta":
            decay_extension = adadelta_eps_decay(args.eps_decay)
        else:
            decay_extension = adam_lr_decay(args.lr_decay)
        trainer.extend(
            decay_extension,
            trigger=CompareValueTrigger(criterion_key, got_worse),
        )

    # Write a log of evaluation statistics every report_interval_iters
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_st",
        "main/loss_asr",
        "validation/main/loss",
        "validation/main/loss_st",
        "validation/main/loss_asr",
        "main/acc",
        "validation/main/acc",
    ]
    if args.asr_weight > 0:
        report_keys.append("main/acc_asr")
        report_keys.append("validation/main/acc_asr")
    report_keys += ["elapsed_time"]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["lr"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.asr_weight > 0:
        if args.mtlalpha > 0:
            report_keys.append("main/cer_ctc")
            report_keys.append("validation/main/cer_ctc")
        if args.mtlalpha < 1:
            if args.report_cer:
                report_keys.append("validation/main/cer")
            if args.report_wer:
                report_keys.append("validation/main/wer")
    if args.report_bleu:
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                              att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Ejemplo n.º 29
0
def main():
    """Train the x2 line-drawing colorization U-Net with Chainer.

    A second U-Net (``cnn_128``), loaded from
    ``models/model_cnn_128_dfl2_9``, is passed to the updater together with
    the network being trained (``cnn``). Only ``cnn`` gets an optimizer here.
    After training, the final model and optimizer states are written to
    ``args.out``.
    """
    parser = argparse.ArgumentParser(
        description='chainer line drawing colorization')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=4,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--dataset',
                        '-i',
                        default='./images/',
                        help='Directory of image files.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    root = args.dataset

    cnn = unet.UNET()
    cnn_128 = unet.UNET()
    serializers.load_npz("models/model_cnn_128_dfl2_9", cnn_128)

    dataset = Image2ImageDatasetX2("dat/images_color_train.dat",
                                   root + "linex2/",
                                   root + "colorx2/",
                                   train=True)
    train_iter = chainer.iterators.SerialIterator(dataset, args.batchsize)

    if args.gpu >= 0:
        # Make a specified GPU current and copy both models to it
        chainer.cuda.get_device_from_id(args.gpu).use()
        cnn.to_gpu()
        cnn_128.to_gpu()

    # Setup optimizer parameters (only `cnn` is optimized).
    opt = optimizers.Adam(alpha=0.0001)
    opt.setup(cnn)
    opt.add_hook(chainer.optimizer.WeightDecay(1e-5), 'hook_cnn')

    # Set up a trainer
    updater = ganUpdater(models=(cnn, cnn_128),
                         iterator={
                             'main': train_iter,
                         },
                         optimizer={'cnn': opt},
                         device=args.gpu)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    # full trainer snapshots are taken half as often as model snapshots
    snapshot_interval2 = (args.snapshot_interval * 2, 'iteration')
    trainer.extend(extensions.dump_graph('cnn/loss'))
    trainer.extend(extensions.snapshot(), trigger=snapshot_interval2)
    trainer.extend(extensions.snapshot_object(
        cnn, 'cnn_x2_iter_{.updater.iteration}'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(opt, 'optimizer_'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=(10, 'iteration'), ))
    trainer.extend(
        extensions.PrintReport(['epoch', 'cnn/loss', 'cnn/loss_rec']))
    trainer.extend(extensions.ProgressBar(update_interval=20))

    # Resume from a snapshot. This must happen before run(); loading it
    # after training would have no effect on the training itself.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()

    # Save the trained model into the trainer's output directory
    chainer.serializers.save_npz(os.path.join(args.out, 'model_final'), cnn)
    chainer.serializers.save_npz(os.path.join(args.out, 'optimizer_final'),
                                 opt)
Ejemplo n.º 30
0
def main():
    """Train a GAN on MNIST with Chainer and dump a batch of generated images.

    The generator uses Adam with a negative learning rate (alpha), so its
    updates ascend its loss instead of descending it. After training, a
    batch of samples from the generator is saved to ``x_gen.npy`` and passed
    to ``save_x``.
    """
    parser = argparse.ArgumentParser(description='GAN_MNIST')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=200,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=30,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    # type=int: without it a value given on the command line stays a string
    parser.add_argument('--z_dim',
                        '-z',
                        type=int,
                        default=2,
                        help='Dimension of random variable')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# z_dim: {}'.format(args.z_dim))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    gen = Generator(args.z_dim)
    dis = Discriminator()
    # Honour "negative value indicates CPU" instead of always using the GPU
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    opt = {
        # The sign of alpha matters: a negative alpha makes the generator's
        # Adam step ascend its loss (presumably the same loss the
        # discriminator descends -- depends on GAN_Updater).
        'gen': optimizers.Adam(alpha=-0.001, beta1=0.5),
        'dis': optimizers.Adam(alpha=0.001, beta1=0.5)
    }
    opt['gen'].setup(gen)
    opt['dis'].setup(dis)

    train, _ = datasets.get_mnist(withlabel=False, ndim=3)

    train_iter = iterators.SerialIterator(train, batch_size=args.batchsize)

    updater = GAN_Updater(train_iter,
                          gen,
                          dis,
                          opt,
                          device=args.gpu,
                          z_dim=args.z_dim)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.dump_graph('loss'))
    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport(['epoch', 'loss', 'loss_gen', 'loss_data']))
    trainer.extend(extensions.ProgressBar(update_interval=100))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()

    # x_gen was previously never defined (NameError after training).
    # NOTE(review): assumes Generator.__call__ takes a (batch, z_dim) float32
    # latent array and that uniform(-1, 1) matches GAN_Updater's sampling --
    # confirm both.
    xp = gen.xp
    z = xp.random.uniform(-1, 1,
                          (args.batchsize, args.z_dim)).astype(xp.float32)
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        x_gen = gen(z)

    np.save('x_gen.npy', cuda.to_cpu(x_gen.data))
    save_x(x_gen)