def main(*args, **kwargs):
    """Training entry point.

    Parses CLI arguments, builds the segmentation model, loss, optimizer,
    LR scheduler and KITTI data loaders, then delegates the actual epoch
    loop to ``utils.train_routine``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Argument parser for the main module. Main module represents train procedure."
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        required=True,
        help="Path to the root dir where models will be stored.")
    parser.add_argument(
        "--dataset-path",
        type=str,
        required=True,
        help=
        "Path to the KITTI dataset which contains 'testing' and 'training' subdirs."
    )
    parser.add_argument("--fold",
                        type=int,
                        default=1,
                        help="Num of a validation fold.")

    # Optimizer options
    parser.add_argument("--optim",
                        type=str,
                        default="SGD",
                        help="Type of optimizer: SGD or Adam")
    parser.add_argument("--lr",
                        type=float,
                        default=1e-3,
                        help="Learning rates for optimizer.")
    parser.add_argument("--momentum",
                        type=float,
                        default=0.9,
                        help="Momentum for SGD optim.")

    # Scheduler options
    parser.add_argument("--scheduler",
                        type=str,
                        default="multi-step",
                        help="Type of a scheduler for LR scheduling.")
    parser.add_argument("--step-st",
                        type=int,
                        default=5,
                        help="Step size for StepLR schedule.")
    parser.add_argument("--milestones",
                        type=str,
                        default="30,70,90",
                        help="List with milestones for MultiStepLR schedule.")
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.1,
        help="Gamma parameter for StepLR and MultiStepLR schedule.")
    parser.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Patience parameter for ReduceLROnPlateau schedule.")

    # Model params
    parser.add_argument("--model-type",
                        type=str,
                        default="reknetm1",
                        help="Type of model. Can be 'RekNetM1' or 'RekNetM2'.")
    parser.add_argument(
        "--decoder-type",
        type=str,
        default="up",
        help=
        "Type of decoder module. Can be 'up'(Upsample) or 'ConvTranspose2D'.")
    parser.add_argument("--init-type",
                        type=str,
                        default="He",
                        help="Initialization type. Can be 'He' or 'Xavier'.")
    parser.add_argument("--act-type",
                        type=str,
                        default="relu",
                        help="Activation type. Can be ReLU, CELU or FTSwish+.")
    # The following 0/1 int flags are converted to bool before model creation.
    parser.add_argument("--enc-bn-enable",
                        type=int,
                        default=1,
                        help="Batch normalization enabling in encoder module.")
    parser.add_argument("--dec-bn-enable",
                        type=int,
                        default=1,
                        help="Batch normalization enabling in decoder module.")
    parser.add_argument("--skip-conn",
                        type=int,
                        default=0,
                        help="Skip-connection in context module.")
    parser.add_argument("--attention",
                        type=int,
                        default=0,
                        help="Attention mechanism in context module.")

    # Other options
    parser.add_argument("--n-epochs",
                        type=int,
                        default=100,
                        help="Number of training epochs.")
    parser.add_argument("--batch-size",
                        type=int,
                        default=4,
                        help="Number of examples per batch.")
    parser.add_argument("--num-workers",
                        type=int,
                        default=8,
                        help="Number of loading workers.")
    parser.add_argument("--device-ids",
                        type=str,
                        default="0",
                        help="ID of devices for multiple GPUs.")
    parser.add_argument("--alpha",
                        type=float,
                        default=0,
                        help="Modulation factor for custom loss.")
    parser.add_argument("--status-every",
                        type=int,
                        default=1,
                        help="Status every parameter.")

    args = parser.parse_args()

    # Console logger: timestamped INFO messages to stdout.
    console_logger = logging.getLogger("console-logger")
    console_logger.setLevel(logging.INFO)
    ch = logging.StreamHandler(stream=sys.stdout)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    console_logger.addHandler(ch)

    console_logger.info(args)

    # Binary road segmentation: a single foreground class.
    num_classes = 1

    if args.decoder_type == "up":
        upsample_enable = True
        console_logger.info("Decoder type is Upsample.")
    elif args.decoder_type == "convTr":
        upsample_enable = False
        console_logger.info("Decoder type is ConvTranspose2D.")
    else:
        # Fail fast: previously an unrecognized decoder type left
        # `upsample_enable` unbound and crashed with a NameError below.
        raise ValueError("Unknown decoder type: {}".format(args.decoder_type))

    # Model definition
    if args.model_type == "reknetm1":
        model = RekNetM1(num_classes=num_classes,
                         ebn_enable=bool(args.enc_bn_enable),
                         dbn_enable=bool(args.dec_bn_enable),
                         upsample_enable=upsample_enable,
                         act_type=args.act_type,
                         init_type=args.init_type)
        console_logger.info("Uses RekNetM1 as the model.")
    elif args.model_type == "reknetm2":
        model = RekNetM2(num_classes=num_classes,
                         ebn_enable=bool(args.enc_bn_enable),
                         dbn_enable=bool(args.dec_bn_enable),
                         upsample_enable=upsample_enable,
                         act_type=args.act_type,
                         init_type=args.init_type,
                         attention=bool(args.attention),
                         use_skip=bool(args.skip_conn))
        console_logger.info("Uses RekNetM2 as the model.")
    elif args.model_type == "lcn":
        model = LidCamNet(num_classes=num_classes, bn_enable=False)
        console_logger.info("Uses LidCamNet as the model.")
    else:
        raise ValueError("Unknown model type: {}".format(args.model_type))

    console_logger.info("Number of trainable parameters: {}".format(
        utils.count_params(model)[1]))

    # Move the model to the GPU(s); DataParallel splits batches across
    # the comma-separated --device-ids.
    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()
    cudnn.benchmark = True

    # Loss definition
    loss = BCEJaccardLoss(alpha=args.alpha)

    dataset_path = Path(args.dataset_path)
    images = str(dataset_path / "training" / droped_valid_image_2_dir)
    masks = str(dataset_path / "training" / train_masks_dir)

    # Train-val splits for cross-validation by a fold.
    ((train_imgs, train_masks),
     (valid_imgs, valid_masks)) = crossval_split(images_paths=images,
                                                 masks_paths=masks,
                                                 fold=args.fold)

    train_dataset = RoadDataset2(img_paths=train_imgs,
                                 mask_paths=train_masks,
                                 transforms=train_transformations())
    valid_dataset = RoadDataset2(img_paths=valid_imgs,
                                 mask_paths=valid_masks,
                                 transforms=valid_tranformations())
    # Separate validation dataset instance configured for F-measure evaluation.
    valid_fmeasure_dataset = RoadDataset2(img_paths=valid_imgs,
                                          mask_paths=valid_masks,
                                          transforms=valid_tranformations(),
                                          fmeasure_eval=True)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=torch.cuda.is_available())
    # Validation batch size equals the GPU count so DataParallel gets
    # exactly one sample per device.
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=torch.cuda.device_count(),
                              num_workers=args.num_workers,
                              pin_memory=torch.cuda.is_available())

    console_logger.info("Train dataset length: {}".format(len(train_dataset)))
    console_logger.info("Validation dataset length: {}".format(
        len(valid_dataset)))

    # Optimizer definition; anything other than SGD falls back to Adam.
    if args.optim == "SGD":
        optim = SGD(params=model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum)
        console_logger.info(
            "Uses the SGD optimizer with initial lr={0} and momentum={1}".
            format(args.lr, args.momentum))
    else:
        optim = Adam(params=model.parameters(), lr=args.lr)
        console_logger.info(
            "Uses the Adam optimizer with initial lr={0}".format(args.lr))

    # Scheduler definition
    if args.scheduler == "step":
        lr_scheduler = StepLR(optimizer=optim,
                              step_size=args.step_st,
                              gamma=args.gamma)
        console_logger.info(
            "Uses the StepLR scheduler with step={} and gamma={}.".format(
                args.step_st, args.gamma))
    elif args.scheduler == "multi-step":
        # --milestones is a comma-separated string, e.g. "30,70,90".
        lr_scheduler = MultiStepLR(
            optimizer=optim,
            milestones=[int(m) for m in (args.milestones).split(",")],
            gamma=args.gamma)
        console_logger.info(
            "Uses the MultiStepLR scheduler with milestones=[{}] and gamma={}."
            .format(args.milestones, args.gamma))
    elif args.scheduler == "rlr-plat":
        lr_scheduler = ReduceLROnPlateau(optimizer=optim,
                                         patience=args.patience,
                                         verbose=True)
        console_logger.info("Uses the ReduceLROnPlateau scheduler.")
    elif args.scheduler == "poly":
        lr_scheduler = PolyLR(optimizer=optim,
                              num_epochs=args.n_epochs,
                              alpha=args.gamma)
        console_logger.info("Uses the PolyLR scheduler.")
    else:
        raise ValueError("Unknown type of schedule: {}".format(args.scheduler))

    valid = utils.binary_validation_routine

    utils.train_routine(args=args,
                        console_logger=console_logger,
                        root=args.root_dir,
                        model=model,
                        criterion=loss,
                        optimizer=optim,
                        scheduler=lr_scheduler,
                        train_loader=train_loader,
                        valid_loader=valid_loader,
                        fm_eval_dataset=valid_fmeasure_dataset,
                        validation=valid,
                        fold=args.fold,
                        num_classes=num_classes,
                        n_epochs=args.n_epochs,
                        status_every=args.status_every)
Exemple #2
0
def main():
    """Training entry point for the VerseDataset segmentation model.

    Parses args, persists the run configuration, builds the model,
    criterion, optimizer and data loaders, then runs the epoch loop with
    best-IoU checkpointing and optional early stopping.
    """
    args = parse_args()
    if args.name is None:
        # Derive a run name from dataset/architecture/loss and the
        # deep-supervision flag.
        if args.deepsupervision:
            args.name = '%s_%s_%s_withDS' % (args.dataset, args.arch,
                                             args.loss)
        else:
            args.name = '%s_%s_%s_withoutDS' % (args.dataset, args.arch,
                                                args.loss)
    if not os.path.exists('trained_models/%s' % args.name):
        os.makedirs('trained_models/%s' % args.name)
    # Record the configuration to the console and to args.txt.
    print('Config --------')
    for arg in vars(args):
        print('%s,%s' % (arg, getattr(args, arg)))
    print('---------------')

    with open("trained_models/%s/args.txt" % args.name, 'w') as f:
        for arg in vars(args):
            print('%s,%s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'trained_models/%s/args.pkl' % args.name)

    # Loss definition
    if args.loss == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[args.loss]().cuda()
    # Let cuDNN pick the fastest kernels for the fixed input size.
    cudnn.benchmark = True

    # Dataset loading
    img_paths = glob(r'F:\Verse_Data\train_data_256x256\img\*')
    mask_paths = glob(r'F:\Verse_Data\train_data_256x256\mask\*')
    train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = \
        train_test_split(img_paths, mask_paths, test_size=0.2, random_state=41)
    print("train_nums:%s" % str(len(train_img_paths)))
    print("val_nums:%s" % str(len(val_img_paths)))

    # Model creation; change args.arch to swap the architecture.
    print("=> creating model: %s " % args.arch)
    trainModel = DAUnetModel.__dict__[args.arch]()
    trainModel = trainModel.cuda()
    params_model = count_params(trainModel) / (1024 * 1024)
    print("参数:%.2f" % (params_model) + "MB")
    with open("trained_models/%s/args.txt" % args.name, 'a') as f:
        print('params-count:%s' % (params_model) + "MB", file=f)

    if args.optimizer == 'Adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      trainModel.parameters()),
                               lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     trainModel.parameters()),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=args.nesterov)
    else:
        # Fail fast: previously an unknown optimizer left `optimizer`
        # unbound and crashed with a NameError in the training loop.
        raise ValueError("Unknown optimizer: %s" % args.optimizer)

    # NOTE(review): the validation set reuses the training augmentation
    # flag (args.aug) — confirm this is intended.
    train_dataset = VerseDataset(args, train_img_paths, train_mask_paths,
                                 args.aug)
    val_dataset = VerseDataset(args, val_img_paths, val_mask_paths, args.aug)

    # drop_last discards the final incomplete batch during training.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=False)

    # Per-epoch metrics accumulated into a CSV log.
    log = pd.DataFrame(
        index=[],
        columns=['epoch', 'lr', 'loss', 'iou', 'val_loss', 'val_iou'])
    best_iou = 0
    trigger = 0  # epochs since the last validation-IoU improvement
    for epoch in range(args.epochs):
        print('Epoch [%d/%d]' % (epoch, args.epochs))

        # train
        train_log = train(args, train_loader, trainModel, criterion, optimizer,
                          epoch)
        # val
        val_log = validate(args, val_loader, trainModel, criterion)

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f' %
              (train_log['loss'], train_log['iou'], val_log['loss'],
               val_log['iou']))

        tmp = pd.Series(
            [
                epoch,
                args.lr,
                train_log['loss'],
                train_log['iou'],
                val_log['loss'],
                val_log['iou'],
            ],
            index=['epoch', 'lr', 'loss', 'iou', 'val_loss', 'val_iou'])

        # NOTE(review): DataFrame.append was removed in pandas >= 2.0;
        # pin pandas < 2 or migrate to pd.concat.
        log = log.append(tmp, ignore_index=True)
        log.to_csv('trained_models/%s/log.csv' % args.name, index=False)
        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(trainModel.state_dict(),
                       'trained_models/%s/model.pth' % args.name)
            best_iou = val_log['iou']
            print('=> saved best model')
            # Keep a full checkpoint (model + optimizer + epoch) of the
            # current best model for resuming.
            checkpoint = {
                "model_state_dict": trainModel.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch
            }
            path_checkpoint = "trained_models/%s/checkpoint_%d_epoch.pkl" % (
                args.name, epoch)
            torch.save(checkpoint, path_checkpoint)
            trigger = 0
        # Early stopping after args.early_stop epochs without improvement.
        if args.early_stop is not None:
            if trigger >= args.early_stop:
                print("=> early stopping")
                break

        torch.cuda.empty_cache()
    def test(self):
        """Evaluate the model on the test set.

        Restores the latest checkpoint from ``args.save_dir``, feeds each
        batch through the graph, prints per-batch loss/error rate and the
        decoded vs. ground-truth sequences, appends them to a result file,
        and logs the overall test error rate.
        """
        # load data
        args = self.args
        batchedData, maxTimeSteps, totalN = self.load_data(args, mode='test', type=args.level)
        model = model_functions[args.model](args, maxTimeSteps)

        # Record parameter counts in the model config (picked up by logging()).
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        with tf.Session(graph=model.graph) as sess:
            # Restore the most recent checkpoint, if one exists.
            ckpt = tf.train.get_checkpoint_state(args.save_dir)
            if ckpt and ckpt.model_checkpoint_path:
                model.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Model restored from:' + args.save_dir)

            # Iterate the batches in random order, accumulating per-batch errors.
            batchErrors = np.zeros(len(batchedData))
            batchRandIxs = np.random.permutation(len(batchedData))
            for batch, batchOrigI in enumerate(batchRandIxs):
                # Targets arrive as a sparse (indices, values, shape) triple.
                batchInputs, batchTargetSparse, batchSeqLengths = batchedData[batchOrigI]
                batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                feedDict = {model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths}

                # Character level: the graph computes the error rate itself.
                if args.level == 'cha':
                    l, pre, y, er = sess.run([model.loss,
                                              model.predictions,
                                              model.targetY,
                                              model.errorRate],
                                             feed_dict=feedDict)
                    batchErrors[batch] = er
                    print('\ntotal:{},batch:{}/{},loss={:.3f},mean CER={:.3f}\n'.format(
                        totalN,
                        batch + 1,
                        len(batchRandIxs),
                        l,
                        er / args.batch_size))

                # Phoneme level: error rate is computed host-side via edit distance.
                elif args.level == 'phn':
                    l, pre, y = sess.run([model.loss,
                                          model.predictions,
                                          model.targetY],
                                         feed_dict=feedDict)
                    er = get_edit_distance([pre.values], [y.values], True, 'test', args.level)
                    print('\ntotal:{},batch:{}/{},loss={:.3f},mean PER={:.3f}\n'.format(
                        totalN,
                        batch + 1,
                        len(batchRandIxs),
                        l,
                        er / args.batch_size))
                    # NOTE(review): 'phn' weights the error by sequence count
                    # while 'cha' stores it raw — confirm this asymmetry.
                    batchErrors[batch] = er * len(batchSeqLengths)

                # NOTE(review): if args.level is neither 'cha' nor 'phn',
                # `y`/`pre` are unbound here and this raises a NameError.
                print('Truth:\n' + output_to_sequence(y, type=args.level))
                print('Output:\n' + output_to_sequence(pre, type=args.level))

                '''
                l, pre, y = sess.run([ model.loss,
                                        model.predictions,
                                        model.targetY],
                                        feed_dict=feedDict)


                er = get_edit_distance([pre.values], [y.values], True, 'test', args.level)
                print(output_to_sequence(y,type=args.level))
                print(output_to_sequence(pre,type=args.level))
                '''
                # Append decoded pairs to the per-task result file.
                with open(args.task + '_result.txt', 'a') as result:
                    result.write(output_to_sequence(y, type=args.level) + '\n')
                    result.write(output_to_sequence(pre, type=args.level) + '\n')
                    result.write('\n')
            # Dataset-level error rate, normalized by the total sample count.
            epochER = batchErrors.sum() / totalN
            print(args.task + ' test error rate:', epochER)
            logging(model, self.logfile, epochER, mode='test')
                                                         100000,
                                                         0.96,
                                                         staircase=True)
    # Passing global_step to minimize() will increment it at each step.
    # learning_step = (tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(...my loss..., global_step=global_step))
    # opt_0 = tf.compat.v1.train.RMSPropOptimizer(learning_rate=learning_rate, decay=0.995).minimize(loss, var_list=[var for var in tf.trainable_variables()], global_step = global_step)
    opt = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate, name='Adam').minimize(
            loss,
            var_list=[var for var in tf.trainable_variables()],
            global_step=global_step)

# Checkpoint saver; retains up to 1000 checkpoints.
saver = tf.train.Saver(max_to_keep=1000)
sess.run(tf.global_variables_initializer())

# Print/record the parameter count of the current graph.
utils.count_params()

# If a pre-trained ResNet is required, load the weights.
# This must be done AFTER the variables are initialized with sess.run(tf.global_variables_initializer())
if init_fn is not None:
    init_fn(sess)

# Load a previous checkpoint if desired
path = args.dataset
# folder_dataset=path.split('/')[-2]
checkpoints_path = "%s/%s" % ("checkpoints", args.checkpoint)
model_checkpoint_name = checkpoints_path + "/latest_model_" + args.model + "_" + ".ckpt"
# model_checkpoint_name = "checkpoints/latest_model_" + args.model + "_" + folder_dataset + ".ckpt"
# NOTE(review): only the checkpoint *directory* is checked (not the .ckpt
# file itself), and the success message prints before restore() can fail —
# confirm this is acceptable.
if args.continue_training and os.path.isdir(checkpoints_path):
    print('Loaded latest model checkpoint')
    saver.restore(sess, model_checkpoint_name)
    def train(self):
        """Run the training loop.

        Loads the training data, optionally restores a previous checkpoint
        (``args.keep``), then for each epoch iterates the batches in random
        order, runs one optimizer step per batch, prints per-batch loss and
        error rate, periodically saves checkpoints, and logs the epoch-level
        mean error rate.
        """
        # load data
        args = self.args
        batchedData, maxTimeSteps, totalN = self.load_data(args, mode='train', type=args.level)
        model = model_functions[args.model](args, maxTimeSteps)

        # count the num of params
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        print(model.config)

        with tf.Session(graph=model.graph) as sess:
            # restore from stored model
            if args.keep == True:
                ckpt = tf.train.get_checkpoint_state(args.save_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + args.save_dir)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            for epoch in range(args.num_epoch):
                ## training
                start = time.time()
                print('Epoch', epoch + 1, '...')
                # Per-batch errors, visited in a fresh random order each epoch.
                batchErrors = np.zeros(len(batchedData))
                batchRandIxs = np.random.permutation(len(batchedData))

                for batch, batchOrigI in enumerate(batchRandIxs):
                    # Targets arrive as a sparse (indices, values, shape) triple.
                    batchInputs, batchTargetSparse, batchSeqLengths = batchedData[batchOrigI]
                    batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                    feedDict = {model.inputX: batchInputs, model.targetIxs: batchTargetIxs,
                                model.targetVals: batchTargetVals, model.targetShape: batchTargetShape,
                                model.seqLengths: batchSeqLengths}

                    # Character level: optimizer step + graph-computed error rate.
                    if args.level == 'cha':
                        _, l, pre, y, er = sess.run([model.optimizer, model.loss,
                                                     model.predictions,
                                                     model.targetY,
                                                     model.errorRate],
                                                    feed_dict=feedDict)
                        batchErrors[batch] = er
                        print('\ntotal:{},batch:{}/{},epoch:{}/{},loss={:.3f},mean CER={:.3f}\n'.format(
                            totalN,
                            batch + 1,
                            len(batchRandIxs),
                            epoch + 1,
                            args.num_epoch,
                            l,
                            er / args.batch_size))

                    # Phoneme level: error rate computed host-side via edit distance.
                    elif args.level == 'phn':
                        _, l, pre, y = sess.run([model.optimizer, model.loss,
                                                 model.predictions,
                                                 model.targetY],
                                                feed_dict=feedDict)
                        er = get_edit_distance([pre.values], [y.values], True, 'train', args.level)
                        print('\ntotal:{},batch:{}/{},epoch:{}/{},loss={:.3f},mean PER={:.3f}\n'.format(
                            totalN,
                            batch + 1,
                            len(batchRandIxs),
                            epoch + 1,
                            args.num_epoch,
                            l,
                            er / args.batch_size))
                        batchErrors[batch] = er * len(batchSeqLengths)

                    # NOTE:
                    # NOTE(review): aborts the epoch when the mean error rate
                    # hits exactly 1.0 — presumably a divergence guard; confirm.
                    if er / args.batch_size == 1.0:
                        break

                    # Show a decoded sample every 30 batches.
                    if batch % 30 == 0:
                        print('Truth:\n' + output_to_sequence(y, type=args.level))
                        print('Output:\n' + output_to_sequence(pre, type=args.level))

                    # Periodic mid-epoch checkpoint (every 20 global steps, or the
                    # very last batch of the last epoch).
                    if (args.save == True) and ((epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or (
                                    epoch == args.num_epoch - 1 and batch == len(batchRandIxs) - 1)):
                        checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                        model.saver.save(sess, checkpoint_path, global_step=epoch)
                        print('Model has been saved in file')
                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' + str(delta_time) + ' s')

                # End-of-epoch checkpoint.
                if args.save == True and (epoch + 1) % 1 == 0:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    model.saver.save(sess, checkpoint_path, global_step=epoch)
                    print('Model has been saved in file')
                # Epoch-level mean error rate, normalized by total sample count.
                epochER = batchErrors.sum() / totalN
                print('Epoch', epoch + 1, 'mean train error rate:', epochER)
                logging(model, self.logfile, epochER, epoch, delta_time, mode='config')
                logging(model, self.logfile, epochER, epoch, delta_time, mode='train')
Exemple #6
0
# Graph inputs: three 6-channel exposure stacks and a 3-channel ground
# truth; batch and spatial dimensions are dynamic (None).
input_exposure_stacks = [
    tf.placeholder(tf.float32, shape=[None, None, None, 6]) for x in range(3)
]
gt_exposure_stack = tf.placeholder(tf.float32, shape=[None, None, None, 3])

# Learning rate is fed per-step through a scalar placeholder.
lr = tf.placeholder("float", shape=[])
network, init_fn = fusion_model_builder.build_model(
    model_name=args.model,
    frontend=args.frontend,
    input_exposure_stack=input_exposure_stacks,
    crop_width=args.crop_width,
    crop_height=args.crop_height,
    is_training=True)

# Report the parameter count (and persist it when logging is enabled).
str_params = utils.count_params()
print(str_params)
if args.save_logs:
    log_file.write(str_params + "\n")

# Loss is computed between tonemapped (log-domain) prediction and target.
if args.loss == 'l2':
    loss = tf.losses.mean_squared_error(log_tonemap(gt_exposure_stack),
                                        log_tonemap(network))
elif args.loss == 'l1':
    loss = tf.losses.absolute_difference(log_tonemap(gt_exposure_stack),
                                         log_tonemap(network))
else:
    # Fail fast: previously an unknown loss left `loss` unbound and
    # crashed with a NameError when building the optimizer below.
    raise ValueError("Unknown loss type: {}".format(args.loss))

opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(
    loss, var_list=[var for var in tf.trainable_variables()])

saver = tf.train.Saver(max_to_keep=1000)
Exemple #7
0
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        feature_dirs, label_dirs = get_data(datadir, level, train_dataset,
                                            dev_dataset, test_dataset, mode)
        batchedData, maxTimeSteps, totalN = self.load_data(
            feature_dirs[0], label_dirs[0], mode, level)
        model = model_fn(args, maxTimeSteps)

        for feature_dir, label_dir in zip(feature_dirs, label_dirs):
            id_dir = feature_dirs.index(feature_dir)
            print('dir id:{}'.format(id_dir))
            batchedData, maxTimeSteps, totalN = self.load_data(
                feature_dir, label_dir, mode, level)
            model = model_fn(args, maxTimeSteps)
            num_params = count_params(model, mode='trainable')
            all_num_params = count_params(model, mode='all')
            model.config['trainable params'] = num_params
            model.config['all params'] = all_num_params
            print(model.config)
            with tf.Session(graph=model.graph) as sess:
                # restore from stored model
                if keep == True:
                    ckpt = tf.train.get_checkpoint_state(savedir)
                    if ckpt and ckpt.model_checkpoint_path:
                        model.saver.restore(sess, ckpt.model_checkpoint_path)
                        print('Model restored from:' + savedir)
                else:
                    print('Initializing')
                    sess.run(model.initial_op)

                for epoch in range(num_epochs):
                    ## training
                    start = time.time()
                    if mode == 'train':
                        print('Epoch {} ...'.format(epoch + 1))

                    batchErrors = np.zeros(len(batchedData))
                    batchRandIxs = np.random.permutation(len(batchedData))

                    for batch, batchOrigI in enumerate(batchRandIxs):
                        batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                            batchOrigI]
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                        feedDict = {
                            model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths
                        }

                        if level == 'cha':
                            if mode == 'train':
                                _, l, pre, y, er = sess.run([
                                    model.optimizer, model.loss,
                                    model.predictions, model.targetY,
                                    model.errorRate
                                ],
                                                            feed_dict=feedDict)

                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                    .format(level, totalN, id_dir + 1,
                                            len(feature_dirs), batch + 1,
                                            len(batchRandIxs), epoch + 1,
                                            num_epochs, l, er / batch_size))

                            elif mode == 'dev':
                                l, pre, y, er = sess.run([
                                    model.loss, model.predictions,
                                    model.targetY, model.errorRate
                                ],
                                                         feed_dict=feedDict)
                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},dev loss={:.3f},mean dev CER={:.3f}\n'
                                    .format(level, totalN, id_dir + 1,
                                            len(feature_dirs), batch + 1,
                                            len(batchRandIxs), l,
                                            er / batch_size))

                            elif mode == 'test':
                                l, pre, y, er = sess.run([
                                    model.loss, model.predictions,
                                    model.targetY, model.errorRate
                                ],
                                                         feed_dict=feedDict)
                                batchErrors[batch] = er
                                print(
                                    '\n{} mode, total:{},subdir:{}/{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                                    .format(level, totalN, id_dir + 1,
                                            len(feature_dirs), batch + 1,
                                            len(batchRandIxs), l,
                                            er / batch_size))
                        elif level == 'seq2seq':
                            raise ValueError('level %s is not supported now' %
                                             str(level))

                        # NOTE:
                        if er / batch_size == 1.0:
                            break

                        if batch % 30 == 0:
                            print('Truth:\n' +
                                  output_to_sequence(y, type=level))
                            print('Output:\n' +
                                  output_to_sequence(pre, type=level))

                        if mode == 'train' and (
                            (epoch * len(batchRandIxs) + batch + 1) % 20 == 0
                                or (epoch == num_epochs - 1
                                    and batch == len(batchRandIxs) - 1)):
                            checkpoint_path = os.path.join(
                                savedir, 'model.ckpt')
                            model.saver.save(sess,
                                             checkpoint_path,
                                             global_step=epoch)
                            print('Model has been saved in {}'.format(savedir))

                    end = time.time()
                    delta_time = end - start
                    print('Epoch ' + str(epoch + 1) + ' needs time:' +
                          str(delta_time) + ' s')

                    if mode == 'train':
                        if (epoch + 1) % 1 == 0:
                            checkpoint_path = os.path.join(
                                savedir, 'model.ckpt')
                            model.saver.save(sess,
                                             checkpoint_path,
                                             global_step=epoch)
                            print('Model has been saved in {}'.format(savedir))
                        epochER = batchErrors.sum() / totalN
                        print('Epoch', epoch + 1, 'mean train error rate:',
                              epochER)
                        logging(model,
                                logfile,
                                epochER,
                                epoch,
                                delta_time,
                                mode='config')
                        logging(model,
                                logfile,
                                epochER,
                                epoch,
                                delta_time,
                                mode=mode)

                    if mode == 'test' or mode == 'dev':
                        with open(
                                os.path.join(resultdir, level + '_result.txt'),
                                'a') as result:
                            result.write(
                                output_to_sequence(y, type=level) + '\n')
                            result.write(
                                output_to_sequence(pre, type=level) + '\n')
                            result.write('\n')
                        epochER = batchErrors.sum() / totalN
                        print(' test error rate:', epochER)
                        logging(model, logfile, epochER, mode=mode)
Exemple #8
0
    def train(self, args):
        """Import data, train the seq2seq RNN model, and save checkpoints.

        Builds a TextParser over <data_dir>/<style>/, optionally restores a
        previous checkpoint (args.keep) or a pretrained model
        (args.pretrained — currently disabled), pickles the run
        configuration and vocabulary, then runs the epoch/batch training
        loop, periodically saving the model and the learned word-embedding
        matrix.

        Args:
            args: argparse-style namespace; reads data_dir, save_dir, style,
                attention, pretrained, keep, model, learning_rate,
                num_epochs, batch_size and save_every, and is mutated with
                vocab_size and parameter counts before being pickled.

        Raises:
            ValueError: if args.pretrained is set (path disabled), if a
                required config/vocab/checkpoint file is missing when
                resuming, or if args.model names an unsupported model.
        """
        # Style-specific sub-directories, e.g. <data_dir>/<style>/.
        args.data_dir = args.data_dir + args.style + '/'
        args.save_dir = args.save_dir + args.style + '/'
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        print(args)
        if args.attention:
            print('attention mode')
        text_parser = TextParser(args)
        args.vocab_size = text_parser.vocab_size
        self.word_embedding_file = os.path.join(args.data_dir,
                                                "word_embedding.pkl")

        if args.pretrained:
            # NOTE: the pretrained path is deliberately disabled by this
            # raise; everything below it in this branch is unreachable
            # until the underlying bug is fixed.
            raise ValueError(
                'pretrained has bug now, so don"t set it to be True now!!!')
            if not args.keep:
                raise ValueError(
                    'when pre-trained is True, keep must be true!')
            print("pretrained and keep mode...")
            print("restoring pretrained model file")
            ckpt = tf.train.get_checkpoint_state(
                "/home/pony/github/jaylyrics_generation_tensorflow/data/pre-trained/"
            )
            if os.path.exists(os.path.join("./data/pre-trained/", 'config.pkl')) and \
                    os.path.exists(os.path.join("./data/pre-trained/", 'words_vocab.pkl')) and \
                    ckpt and ckpt.model_checkpoint_path:
                with open(os.path.join("./data/pre-trained/", 'config.pkl'),
                          'rb') as f:
                    saved_model_args = cPickle.load(f)
                with open(
                        os.path.join("./data/pre-trained/", 'words_vocab.pkl'),
                        'rb') as f:
                    saved_words, saved_vocab = cPickle.load(f)
            else:
                raise ValueError('configuration doesn"t exist!')
        else:
            ckpt = tf.train.get_checkpoint_state(args.save_dir)

        if args.keep and not args.pretrained:
            # Resuming: config, vocab and a valid checkpoint must all exist.
            if os.path.exists(os.path.join(args.save_dir, 'config.pkl')) and \
                    os.path.exists(os.path.join(args.save_dir, 'words_vocab.pkl')) and \
                    ckpt and ckpt.model_checkpoint_path:
                with open(os.path.join(args.save_dir, 'config.pkl'),
                          'rb') as f:
                    saved_model_args = cPickle.load(f)
                with open(os.path.join(args.save_dir, 'words_vocab.pkl'),
                          'rb') as f:
                    saved_words, saved_vocab = cPickle.load(f)
            else:
                raise ValueError('configuration doesn"t exist!')

        if args.model == 'seq2seq_rnn':
            model = Model_rnn(args)
        else:
            # Fail fast: the original silently fell through here, leaving
            # `model` unbound and crashing later with a NameError in
            # count_params().
            raise ValueError('model %s is not supported now' %
                             str(args.model))

        # Record parameter counts in the pickled config snapshot.
        trainable_num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        args.num_trainable_params = trainable_num_params
        args.num_all_params = all_num_params
        with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
            cPickle.dump(args, f)
        with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
            cPickle.dump((text_parser.vocab_dict, text_parser.vocab_list), f)

        with tf.Session() as sess:
            if args.keep:
                print('Restoring')
                model.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            sess.run(tf.assign(model.lr, args.learning_rate))
            for e in range(args.num_epochs):
                start = time.time()
                # Reset the RNN state at the start of every epoch.
                model.initial_state = tf.convert_to_tensor(model.initial_state)
                state = model.initial_state.eval()
                total_loss = []
                for b in range(text_parser.num_batches):
                    x, y = text_parser.next_batch()
                    if args.attention:
                        # Fresh random attention states for every batch.
                        attention_states = sess.run(
                            tf.truncated_normal([
                                args.batch_size, model.attn_length,
                                model.attn_size
                            ],
                                                stddev=0.1,
                                                dtype=tf.float32))
                        feed = {
                            model.input_data: x,
                            model.targets: y,
                            model.initial_state: state,
                            model.attention_states: attention_states
                        }
                    else:
                        feed = {
                            model.input_data: x,
                            model.targets: y,
                            model.initial_state: state
                        }

                    # Carry the final RNN state into the next batch.
                    train_loss, state, _, word_embedding = sess.run([
                        model.cost, model.final_state, model.train_op,
                        model.word_embedding
                    ], feed)
                    total_loss.append(train_loss)

                    print("{}/{} (epoch {}), train_loss = {:.3f}"
                          .format(e * text_parser.num_batches + b,
                                  args.num_epochs * text_parser.num_batches,
                                  e, train_loss))

                    # Periodic checkpoint of both model and embeddings.
                    if (e * text_parser.num_batches +
                            b) % args.save_every == 0:
                        checkpoint_path = os.path.join(args.save_dir,
                                                       'model.ckpt')
                        model.saver.save(sess, checkpoint_path, global_step=e)
                        print("model has been saved in:" +
                              str(checkpoint_path))
                        np.save(self.word_embedding_file, word_embedding)
                        print("word embedding matrix has been saved in:" +
                              str(self.word_embedding_file))

                end = time.time()
                delta_time = end - start
                ave_loss = np.array(total_loss).mean()
                logging(model, ave_loss, e, delta_time, mode='train')
                # Stop early once the epoch-mean loss is low enough,
                # saving a final checkpoint first.
                if ave_loss < 0.1:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    model.saver.save(sess, checkpoint_path, global_step=e)
                    print("model has been saved in:" + str(checkpoint_path))
                    np.save(self.word_embedding_file, word_embedding)
                    print("word embedding matrix has been saved in:" +
                          str(self.word_embedding_file))
                    break
Exemple #9
0
def main():
    """Entry point for 3D U-Net segmentation training.

    Parses CLI arguments, snapshots the run configuration, splits the
    image/mask paths 70/30, builds the model/criterion/optimizer, then
    runs the epoch loop with per-epoch CSV logging, best-val-loss
    checkpointing and optional early stopping.
    """
    args = parse_args()
    #args.dataset = "datasets"

    if args.name is None:
        # NOTE(review): both branches build the identical name, so the
        # deepsupervision flag currently has no effect on args.name.
        if args.deepsupervision:
            args.name = '%s_%s_lym' % (args.dataset, args.arch)
        else:
            args.name = '%s_%s_lym' % (args.dataset, args.arch)
    # Each run gets its own timestamped directory under models/<name>/.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    if not os.path.exists('models/{}/{}'.format(args.name, timestamp)):
        os.makedirs('models/{}/{}'.format(args.name, timestamp))

    # Echo the full configuration to stdout ...
    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    # ... and persist it with the run, both human-readable and pickled.
    with open('models/{}/{}/args.txt'.format(args.name, timestamp), 'w') as f:
        for arg in vars(args):
            print('%s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/{}/{}/args.pkl'.format(args.name, timestamp))

    # define loss function (criterion); anything other than
    # 'BCEWithLogitsLoss' falls back to the combined BCE + Dice loss.
    if args.loss == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.BCEDiceLoss().cuda()

    cudnn.benchmark = True

    # Data loading code — image/mask pairs are matched by glob order;
    # presumably both directories list corresponding files in the same
    # order. TODO confirm the pairing is stable.
    img_paths = glob('./data/train_image/*')
    mask_paths = glob('./data/train_mask/*')

    # Fixed random_state keeps the 70/30 split reproducible across runs.
    train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = \
        train_test_split(img_paths, mask_paths, test_size=0.3, random_state=39)
    print("train_num:%s" % str(len(train_img_paths)))
    print("val_num:%s" % str(len(val_img_paths)))

    # create model (single-channel input, 2 output classes)
    print("=> creating model %s" % args.arch)
    model = UNet.UNet3d(in_channels=1, n_classes=2, n_channels=32)
    model = torch.nn.DataParallel(model).cuda()
    #model._initialize_weights()
    #model.load_state_dict(torch.load('model.pth'))

    print(count_params(model))

    # Optimizer over trainable parameters only.
    # NOTE(review): no else branch — an unrecognized --optimizer value
    # leaves `optimizer` unbound and raises NameError at train() below.
    if args.optimizer == 'Adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=args.nesterov)

    # Augmentation (args.aug) is applied to the training set only.
    train_dataset = Dataset(args, train_img_paths, train_mask_paths, args.aug)
    val_dataset = Dataset(args, val_img_paths, val_mask_paths)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=False)

    # Per-epoch metrics accumulated here and flushed to log.csv each epoch.
    log = pd.DataFrame(index=[],
                       columns=[
                           'epoch', 'lr', 'loss', 'iou', 'dice_1', 'dice_2',
                           'val_loss', 'val_iou', 'val_dice_1', 'val_dice_2'
                       ])

    best_loss = 100
    # best_iou = 0
    # `trigger` counts epochs since the last val-loss improvement,
    # driving the early-stopping check below.
    trigger = 0
    first_time = time.time()
    for epoch in range(args.epochs):
        print('Epoch [%d/%d]' % (epoch, args.epochs))

        # train for one epoch
        train_log = train(args, train_loader, model, criterion, optimizer,
                          epoch)
        # evaluate on validation set
        val_log = validate(args, val_loader, model, criterion)

        print(
            'loss %.4f - iou %.4f - dice_1 %.4f - dice_2 %.4f - val_loss %.4f - val_iou %.4f - val_dice_1 %.4f - val_dice_2 %.4f'
            % (train_log['loss'], train_log['iou'], train_log['dice_1'],
               train_log['dice_2'], val_log['loss'], val_log['iou'],
               val_log['dice_1'], val_log['dice_2']))

        # Elapsed wall-clock time since training started, in minutes.
        end_time = time.time()
        print("time:", (end_time - first_time) / 60)

        tmp = pd.Series([
            epoch,
            args.lr,
            train_log['loss'],
            train_log['iou'],
            train_log['dice_1'],
            train_log['dice_2'],
            val_log['loss'],
            val_log['iou'],
            val_log['dice_1'],
            val_log['dice_2'],
        ],
                        index=[
                            'epoch', 'lr', 'loss', 'iou', 'dice_1', 'dice_2',
                            'val_loss', 'val_iou', 'val_dice_1', 'val_dice_2'
                        ])

        log = log.append(tmp, ignore_index=True)
        log.to_csv('models/{}/{}/log.csv'.format(args.name, timestamp),
                   index=False)

        trigger += 1

        # Checkpoint whenever validation loss improves; the filename
        # embeds the epoch and both dice scores.
        val_loss = val_log['loss']
        if val_loss < best_loss:
            torch.save(
                model.state_dict(),
                'models/{}/{}/epoch{}-{:.4f}-{:.4f}_model.pth'.format(
                    args.name, timestamp, epoch, val_log['dice_1'],
                    val_log['dice_2']))
            best_loss = val_loss
            print("=> saved best model")
            trigger = 0

        # early stopping
        if not args.early_stop is None:
            if trigger >= args.early_stop:
                print("=> early stopping")
                break

        torch.cuda.empty_cache()