示例#1
0
def get_local_protocol_config(data_name,
                              config,
                              logger,
                              regenerate_protocol_files=False):
    """
    get_local_protocol_config: load the evaluation-protocol JSON of a dataset.

    :param data_name: dataset name passed to ``config.get_dataset_param``
    :param config: project configuration object
    :param logger: logger used for progress messages
    :param regenerate_protocol_files: when True, rebuild the protocol files via
        ``generate_protocol_files`` and load the file it returns; otherwise
        load ``dataset_param["eval_config_file"]`` directly
    :return: the parsed JSON content of the protocol file
    """
    dataset_param = config.get_dataset_param(data_name)

    if not regenerate_protocol_files:
        # NOTE(review): this path is used verbatim, whereas the protocol
        # generators write to eval_config_file.format(file_list_dir, name);
        # both only agree if the template has no placeholders — confirm.
        config_path = dataset_param["eval_config_file"]
        logger.info("Employ protocol from '{:s}'".format(config_path))
        logger.info(separator_line())
    else:
        # rebuild the train/test file lists and the protocol JSON
        config_path = generate_protocol_files(data_name, dataset_param, logger)

    with open(config_path, "r") as json_handle:
        return json.load(json_handle)
示例#2
0
def fetching_dataset(args, config, logger, subset):
    """
    fetching_dataset: build the dataset and its DataLoader for one subset.

    :param args: run options (dataset, modality, eval_protocol, transform
        names, loader settings, ...)
    :param config: project configuration object
    :param logger: logger for progress messages
    :param subset: subset key of the protocol (e.g. "train" / "test"); only
        "train" honours ``args.train_shuffle`` and may regenerate protocols
    :return: (DataLoader, shape of the data tensor of the first sample)
    :raises ValueError: if ``args.dataset`` is not a known dataset
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.dataset not in config.get_defined_datasets_list():
        raise ValueError("Unknown dataset: {:s}".format(args.dataset))

    # Protocol files are regenerated (when requested) only while building the
    # training subset, so they are rebuilt at most once per run.
    protocol_config = get_local_protocol_config(
        args.dataset, config, logger,
        bool(args.regenerate_protocol_files) and subset == "train")

    index_param = {
        "file_format": protocol_config["file_format"][args.modality],
        "file_subset_list": protocol_config["eval_Protocol"][
            args.eval_protocol][subset],
        "subset_type": subset,
        "small_validation": args.small_validation,
        "modality": args.modality,
        "data_dosage": args.data_dosage,
        "is_gradient": args.is_gradient,
        "is_segmented": args.is_segmented,
        "logger": logger,
        "add_noise": args.add_noise,
        "use_depth_seq": args.use_depth_seq,
    }

    # configure transforms
    loader_param = config.get_loader_param(args.dataset, args.modality)

    spatial_param = loader_param["spatial_transform"][
        args.spatial_transform][subset]
    temporal_param = loader_param["temporal_transform"][
        args.temporal_transform][subset]

    index_param["temporal_param"] = temporal_param

    spatial_transform, temporal_transform = group_data_transforms(
        spatial_param, temporal_param, args.modality, args.is_gradient,
        args.is_gradient_normal, args.use_depth_seq)

    if args.dataset == 'MSRDailyAct3D':
        target_dataset = MSRDailyAct3D(index_param,
                                       spatial_transform=spatial_transform,
                                       temporal_transform=temporal_transform)
    elif args.dataset == "MSRAction3D":
        target_dataset = MSRAction3D(index_param,
                                     spatial_transform=spatial_transform,
                                     temporal_transform=temporal_transform)
    elif args.dataset == "NTU_RGB+D":
        target_dataset = NTURGBD(index_param,
                                 spatial_transform=spatial_transform,
                                 temporal_transform=temporal_transform)
    else:
        raise ValueError("Unknown dataset: '{:s}'".format(args.dataset))

    logger.info("[{:s}] spatial parameters of {:s} are: ".format(
        subset, args.spatial_transform))
    logger.info(json.dumps(spatial_param, indent=4))
    logger.info("[{:s}] temporal parameters of {:s} are: ".format(
        subset, args.temporal_transform))
    logger.info(json.dumps(temporal_param, indent=4))

    logger.info(separator_line())

    # Only the training subset may be shuffled; evaluation order stays fixed.
    shuffle = args.train_shuffle if subset == "train" else False
    data_to_fetch = torch.utils.data.DataLoader(
        target_dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory)

    if args.plot_confusion_matrix:
        # consumed by validate_model when drawing the confusion matrix
        data_to_fetch.class_names = protocol_config["action_Names"]

    # the dataset yields (data, target, extra) triples
    sample_data, _, _ = target_dataset[0]

    return data_to_fetch, sample_data.shape
示例#3
0
def training_model(args, config, logger, train_loader, model_combined, val_loader):
    """
    training_model: run the full train / validate / checkpoint loop.

    Each epoch optionally adjusts the learning rate, trains one epoch,
    validates, saves a checkpoint (on a new best validation accuracy or every
    ``args.checkpoint_interval`` epochs), logs to CSV and TensorBoard, and
    stops early when the over-fitting or flat-training-loss criteria trigger.

    :param args: run options (resume, epochs, adjust_lr, checkpoint_interval,
        overfit_threshold, tb_writer, code, net_arch, data_dosage, log_name, ...)
    :param config: project configuration object
    :param logger: logger for progress messages
    :param train_loader: training data loader, forwarded to train_epoch
    :param model_combined: [model, metrics, criterion, optimizer]
    :param val_loader: validation data loader, forwarded to validate_model
    :return: None
    """
    [model, metrics, criterion, optimizer] = model_combined

    # Defaults for a fresh run. They are set before the resume branch so that
    # a --resume path that does not exist still leaves them bound.
    start_epoch = 0
    best_prec = 0
    best_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        logger.info("=> Recovery model training from checkpoint")
        if os.path.isfile(args.resume):
            logger.info("{:s} loading checkpoint '{}'".format(datetime_now_string(), args.resume))
            checkpoint = torch.load(args.resume)

            # load param from checkpoint file
            start_epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec']
            # 'best_epoch' is only stored by checkpoints written by this
            # function version; older checkpoints keep the previous fallback.
            best_epoch = checkpoint.get('best_epoch', checkpoint['epoch'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

            logger.info("{:s} loaded checkpoint '{}' (epoch {})".format(datetime_now_string(),
                                                                        args.resume, checkpoint['epoch']))
        else:
            logger.warning("Warning: No checkpoint found at '{}'".format(args.resume))

        logger.info(separator_line())

    # get checkpoint saved directory
    checkpoint_root = config.get_path_param(args.dataset)["checkpoint_root"]
    ensure_directory(checkpoint_root)

    logger.info("{:s} Start model training".format(datetime_now_string()))
    logger.info(separator_line())

    lr_initial = optimizer.param_groups[0]["lr"]

    # "disable" (or an empty / None value) switches the LR schedule off.
    # Exact comparison: a substring test would also match e.g. "dis" or "able".
    adjust_lr_enabled = bool(args.adjust_lr) and args.adjust_lr != "disable"

    # record epoch and batch logger with csv file
    train_epoch_logger = CSVLogger(train_epoch_csv_name.format(args.log_name) + ".temp",
                                   ['epoch', 'train_loss', 'train_accuracy',
                                    'validate_loss', 'validate_accuracy', 'lr'])

    train_batch_logger = CSVLogger(train_batch_csv_name.format(args.log_name) + ".temp",
                                   ['epoch', 'batch', 'iteration', 'loss', 'accuracy', 'lr'])

    # ---- early stop state ----
    is_overfit_before = False
    is_overfit_count = 0

    # sliding window of the most recent training losses (zero-padded at start)
    loss_window = 10
    recent_losses = np.zeros(loss_window)

    # True only when one of the early-stop criteria broke the epoch loop
    stopped_early = False

    # ---- start the epochs ----
    for epoch in range(start_epoch, args.epochs):

        logger.info(epochs_format(epoch, args.epochs))
        logger.info(separator_line())

        # if adjust the learning rate
        if adjust_lr_enabled:
            adjust_lr_param = config.get_model_param()["adjust_lr"][args.adjust_lr]
            adjust_lr_param["lr_method"] = args.adjust_lr

            adjust_learning_rate(adjust_lr_param, lr_initial, optimizer, epoch)

        epoch_idx = [epoch, args.epochs]

        # train for one epoch
        train_prec, train_loss = train_epoch(args, logger, train_loader, model_combined,
                                             epoch_idx, train_batch_logger)

        # evaluate on validation set
        valid_prec, valid_loss = validate_model(args, logger, val_loader, model_combined, epoch)

        # remember the best predict accuracy and save the checkpoint
        is_best = valid_prec > best_prec
        best_prec = max(valid_prec, best_prec)

        is_check = epoch % args.checkpoint_interval == 0

        if is_best:
            logger.info("{:s} checkpoint at epoch: {:d} with accuracy: {:.2f}".format(datetime_now_string(),
                                                                                         epoch, best_prec))
            logger.info(separator_line())
            checkpoint_file_name = "{:s}/{:s}_{:2.0f}_{:s}_model_best.pth.tar".format(checkpoint_root, args.code,
                                                                                      args.data_dosage * 100,
                                                                                      args.net_arch)
            best_epoch = epoch
        else:
            checkpoint_file_name = "{:s}/{:s}_{:s}_model_checkpoint.pth.tar".format(checkpoint_root,
                                                                                    args.code,
                                                                                    args.net_arch)
        # save checkpoint
        if is_check or is_best:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.net_arch,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
                'best_epoch': best_epoch,
                'optimizer': optimizer.state_dict(),
            }, checkpoint_file_name)

        current_lr = optimizer.param_groups[0]['lr']

        # csv epoch logger
        train_epoch_logger.log({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_accuracy': train_prec,
            'validate_loss': valid_loss,
            'validate_accuracy': valid_prec,
            'lr': current_lr
        })

        # tensorboard writer for epoch
        args.tb_writer.add_scalars('data/epoch_loss', {'train_loss': train_loss,
                                                       'validate_loss': valid_loss},
                                   epoch + 1)
        args.tb_writer.add_scalars('data/epoch_accuracy', {'train_accuracy': train_prec,
                                                           'validate_accuracy': valid_prec},
                                   epoch + 1)
        args.tb_writer.add_scalar('data/epoch_lr', current_lr, epoch + 1)

        # ---- stop when the model has looked over-fitted for too many
        # consecutive epochs (train accuracy high, validation accuracy low)
        if train_prec > args.overfit_threshold[0] and valid_prec < args.overfit_threshold[1]:
            is_overfit_count = is_overfit_count + 1 if is_overfit_before else 0
            is_overfit_before = True
            if is_overfit_count > args.overfit_threshold[2]:
                logger.warning("Warning: Iteration has been terminated as the model is over fitting! "
                               "Best accuracy is {:.2f}% on epoch {:d}.".format(best_prec, best_epoch))
                stopped_early = True
                break
        else:
            is_overfit_before = False

        # ---- stop when the training loss has been flat over the window
        recent_losses = np.append(recent_losses[-(loss_window - 1):], train_loss)
        if recent_losses.std() < 1e-4:
            logger.warning("Warning: Training Loss has been no more declining! "
                           "Best accuracy is {:.2f}% on epoch {:d}.".format(best_prec, best_epoch))
            stopped_early = True
            break

    # Report completion unless an early-stop warning already did.
    if not stopped_early:
        logger.info("Total {:d} epochs of model training have finished, "
                    "best accuracy is {:.2f}% on epoch {:d}.".format(args.epochs, best_prec, best_epoch))
示例#4
0
def validate_model(args, logger, val_loader, model_combined, epoch):
    """
    validate_model: evaluate the model on the validation loader.

    :param args: run options; reads cuda, plot_confusion_matrix and tb_writer
    :param logger: logger for the validation summary line
    :param val_loader: validation DataLoader yielding (input, target, _)
        batches; must carry ``class_names`` when plotting the confusion matrix
    :param model_combined: [model, metrics, criterion, optimizer]
    :param epoch: step used when adding the confusion-matrix figure
    :return: (average accuracy, average loss) over the validation set
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()
    samples_count = AverageMeter()
    samples_right = AverageMeter()

    [model, metrics, criterion, _] = model_combined

    # top-1 predictions are only needed for the confusion matrix
    collect_predictions = args.plot_confusion_matrix
    y_true = []
    y_pred = []

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for inputs, target, _ in val_loader:
            # measure data loading time
            data_time.update(time.time() - end)

            # args.cuda is treated as a boolean flag, exactly as in train_epoch
            if args.cuda:
                inputs = inputs.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            if collect_predictions:
                y_true.extend(target.tolist())
                _, pred = output.topk(1, 1, True, True)
                y_pred.extend(pred.t().tolist()[0])

            # measure accuracy and record loss
            if "accuracy_percent" in metrics:
                predicted_accuracy, n_correct_elems = calculate_accuracy_percent(
                    output, target)
                samples_count.update(inputs.size(0))
                samples_right.update(n_correct_elems.item())
                accuracy.update(predicted_accuracy.item(), inputs.size(0))

            # TODO: add more metrics such as recall for special dataset

            losses.update(loss.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    if collect_predictions:
        plt_fig = generate_confusion_matrix(y_true, y_pred, val_loader.class_names)
        args.tb_writer.add_figure('confusion matrix', plt_fig, epoch)

    logger.info('=> Validate:  '
                'Elapse: {data_time.sum:.2f}/{sum_time.sum:.2f}s  '
                'Loss: {loss.avg:.4f}  '
                'Accuracy: {acc.avg:.2f}% '
                '[{right:.0f}/{count:.0f}]'.format(
                    loss=losses,
                    data_time=data_time,
                    sum_time=batch_time,
                    acc=accuracy,
                    right=samples_right.sum,
                    count=samples_count.sum))
    logger.info(separator_line())

    return accuracy.avg, losses.avg
示例#5
0
def generate_protocol_for_MSRDailyAct3D(dataset_param, logger):
    """
    generate_protocol_for_MSRDailyAct3D: write the train/test file lists and
    the JSON protocol description for the MSRDailyAct3D dataset.

    Every (subject, exp, action) index triple whose depth file exists (and,
    when ``data_constant_check`` is set, whose skeleton file exists too)
    becomes one tab-separated line: file string, 0-based label, frame count.
    For each protocol, subjects listed in ``eval_subs`` go to the test list
    and all other subjects to the train list.

    :param dataset_param: dataset section of the configuration (data_dir,
        data_attr, data_type, file_format, data_constant_check,
        file_list_dir, eval_protocols, eval_config_file)
    :param logger: logger for progress messages
    :return: path of the written JSON protocol file
    """
    data_Name = "MSRDailyAct3D"

    data_attr = dataset_param["data_attr"]
    data_type = dataset_param["data_type"]

    protocols = list(dataset_param["eval_protocols"].keys())

    # AssertionError for none protocols
    assert len(protocols) > 0, "evaluation protocol should be declared!"

    # Preparing new dict for json file
    data_param_dict = {
        "data_Name": data_Name,
        "data_Path": dataset_param["data_dir"],
        "file_format": {
            "Depth": data_type["Depth"],
            "Skeleton": data_type["Skeleton"],
            "RGB": data_type["RGB"],
            "Pre_gradient": data_type["Pre_gradient"],
        },
        "eval_Protocol": {},
        "action_Names": data_attr["n_names"],
    }

    file_list_dir = dataset_param["file_list_dir"]
    if not os.path.exists(file_list_dir):
        os.makedirs(file_list_dir)

    # The candidate samples (and their depth-header reads) do not depend on the
    # protocol, so the dataset is scanned once; protocols only differ in how
    # subjects are split between train and test.
    constant_check = dataset_param["data_constant_check"]
    samples = []      # (subject index, protocol line)
    action_list = []  # action index of every accepted sample
    for n_si in range(1, data_attr["n_subs"] + 1):
        for n_ei in range(1, data_attr["n_exps"] + 1):
            for n_ai in range(1, data_attr["n_acts"] + 1):
                file_str = dataset_param["file_format"].format(n_ai, n_si, n_ei)

                depth_file = data_type["Depth"].replace("<$file_format>", file_str)
                if not os.path.isfile(depth_file):
                    continue

                # constant check: the matching skeleton file must exist as well
                if constant_check:
                    skeleton_file = data_type["Skeleton"].replace("<$file_format>", file_str)
                    if not os.path.isfile(skeleton_file):
                        continue

                action_list.append(n_ai)

                # get the depth temporal length
                header_info = read_msr_depth_maps(depth_file,
                                                  seqs_idx=None,
                                                  header_only=True)
                line_str = "{file_str:s}\t{label:d}\t{frames:d}".format(
                    file_str=file_str,
                    label=n_ai - 1,
                    frames=header_info[0])
                samples.append((n_si, line_str))

    # constant check for num classes
    num_classes = len(set(action_list))

    if num_classes != len(data_attr["n_names"]):
        logger.warning(
            "Warning: num classes: {:d} is not equal to class names: {}.".
            format(num_classes, len(data_attr["n_names"])))

    for p_i, protocol_i in enumerate(protocols):

        protocol_param = dataset_param["eval_protocols"][protocol_i]
        eval_subs = protocol_param["eval_subs"]

        train_list = [line for n_si, line in samples if n_si not in eval_subs]
        test_list = [line for n_si, line in samples if n_si in eval_subs]

        # protocol name
        protocol_item = "{:d}_{:s}".format(p_i + 1, protocol_i)

        logger.info(protocol_item + ": ")

        # Write train list to file
        train_list_file = protocol_param["file_list_train"].replace(
            "<$protocol_item>", protocol_item)
        with open(train_list_file, "w") as trlf:
            trlf.writelines(line + "\n" for line in train_list)
        logger.info("    Train filelist has been stored in '{:s}'".format(
            train_list_file))

        # Write test list to file
        test_list_file = protocol_param["file_list_test"].replace(
            "<$protocol_item>", protocol_item)
        with open(test_list_file, "w") as telf:
            telf.writelines(line + "\n" for line in test_list)
        logger.info("    Test filelist has been stored in '{:s}'".format(
            test_list_file))

        logger.info(
            "    => Summary: {:d} samples for training and {:d} samples for test."
            .format(len(train_list), len(test_list)))
        logger.info("    => Number of classes: {:d}".format(num_classes))

        logger.info(separator_line(dis_len="half"))

        assert len(train_list) > 0 and len(
            test_list) > 0, "Target dataset has no samples to read."
        data_param_dict["eval_Protocol"][protocol_item] = {
            "train": train_list_file,
            "test": test_list_file,
        }

    # num_classes to eval_Protocol
    data_param_dict["num_classes"] = num_classes

    # write protocol param to json file
    data_param_dict_file = dataset_param["eval_config_file"].format(
        file_list_dir, data_Name)

    with open(data_param_dict_file, 'w') as jsf:
        json.dump(data_param_dict, jsf, indent=4)
    logger.info("Evaluation protocols have been stored in '{:s}'".format(
        data_param_dict_file))
    logger.info(separator_line())

    return data_param_dict_file
示例#6
0
def train_epoch(args, logger, train_loader, model_combined, epoch_idx, train_batch_logger):
    """
    train_epoch: train the model for one epoch.

    :param args: run options; reads cuda, log_interval, batch_size and tb_writer
    :param logger: logger for periodic progress lines and the epoch summary
    :param train_loader: training DataLoader yielding (input, target, _) batches
    :param model_combined: [model, metrics, criterion, optimizer]
    :param epoch_idx: [current epoch index, total number of epochs]
    :param train_batch_logger: CSV logger that receives one row per batch
    :return: (average accuracy, average loss) over the epoch
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()

    [model, metrics, criterion, optimizer] = model_combined
    [epoch, epochs] = epoch_idx

    n_batches = len(train_loader)

    # Synchronise so the timers measure completed GPU work. The call raises
    # when CUDA is not available, so it is skipped on CPU-only machines.
    sync_cuda = torch.cuda.is_available()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (inputs, target, _) in enumerate(train_loader):
        if sync_cuda:
            torch.cuda.synchronize()
        # measure data loading time
        data_time.update(time.time() - end)

        if args.cuda:
            inputs = inputs.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output
        output = model(inputs)
        loss = criterion(output, target)

        # measure metrics, eg., accuracy and record loss
        if "accuracy_percent" in metrics:
            predicted_accuracy, _ = calculate_accuracy_percent(output, target)
            accuracy.update(predicted_accuracy.item(), inputs.size(0))

        # TODO: add more metrics such as recall for special dataset

        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if sync_cuda:
            torch.cuda.synchronize()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # global iteration index across epochs
        n_iter = epoch * n_batches + i

        if i % args.log_interval == 0:
            process_rate = 100.0 * n_iter / (epochs * n_batches)
            logger.info(bacth_format(process_rate, i + 1, args.batch_size,
                                     len(train_loader.dataset), losses, accuracy))

        current_lr = optimizer.param_groups[0]['lr']

        # csv batch logger
        train_batch_logger.log({
            'epoch': epoch + 1,
            'batch': i + 1,
            'iteration': n_iter + 1,
            'loss': losses.avg,
            'accuracy': accuracy.val,
            'lr': current_lr
        })

        # tensorboard writer
        args.tb_writer.add_scalar('data/batch_loss', losses.avg, n_iter)
        args.tb_writer.add_scalar('data/batch_accuracy', accuracy.val, n_iter)
        args.tb_writer.add_scalar('data/batch_lr', current_lr, n_iter)

    # print the average loss and metrics
    logger.info(separator_line(dis_len="half"))
    logger.info("=> Training:  "
                "Elapse: {data_time.sum:.2f}/{sum_time.sum:.2f}s  "
                "Loss: {loss.avg:.4f}  "
                "Accuracy: {acc.avg:.2f}%".format(loss=losses,
                                                  data_time=data_time,
                                                  sum_time=batch_time,
                                                  acc=accuracy))
    return accuracy.avg, losses.avg
示例#7
0
def generate_protocol_for_NTU_RGBD(dataset_param, logger):
    """
    generate_protocol_for_NTU_RGBD: write the train/test file lists and the
    JSON protocol description for the NTU RGB+D dataset.

    Every entry of the depth directory (``data_type["Depth"]``) is a sample
    folder. The action, subject and camera ids are parsed from the three
    characters that follow the first 'A', 'P' and 'C' in its name (NTU-style
    names such as "S001C001P001R001A001" are assumed). The frame count is
    the number of ".png" files in the folder. Protocol "cross_view" trains on
    cameras in ``train_cam``; "cross_subjects" trains on subjects in
    ``train_subs``.

    :param dataset_param: dataset section of the configuration (data_dir,
        data_attr, data_type, file_list_dir, eval_protocols, eval_config_file)
    :param logger: logger for progress messages
    :return: path of the written JSON protocol file
    :raises ValueError: on an empty sample folder or an unknown protocol name
    """
    data_Name = "NTU_RGB+D"

    data_attr = dataset_param["data_attr"]
    data_type = dataset_param["data_type"]

    protocols = list(dataset_param["eval_protocols"].keys())

    # AssertionError for none protocols
    assert len(protocols) > 0, "evaluation protocol should be declared!"

    # Preparing new dict for json file
    data_param_dict = {
        "data_Name": data_Name,
        "data_Path": dataset_param["data_dir"],
        "file_format": {
            "Depth": data_type["Depth"],
            "Skeleton": data_type["Skeleton"],
        },
        "eval_Protocol": {},
        "action_Names": data_attr["n_names"],
    }

    file_list_dir = dataset_param["file_list_dir"]
    if not os.path.exists(file_list_dir):
        os.makedirs(file_list_dir)

    def parse_field(name, tag):
        # integer formed by the 3 characters after the first `tag` in `name`
        start = name.find(tag) + 1
        return int(name[start:start + 3])

    # Scan the sample folders once; protocols only differ in the split.
    # sorted() makes the generated lists reproducible (os.listdir order is
    # arbitrary).
    data_path = data_type["Depth"]
    samples = []      # (subject id, camera id, protocol line)
    action_list = []  # action id of every sample
    for filename in sorted(os.listdir(data_path)):
        action_class = parse_field(filename, 'A')
        subject_id = parse_field(filename, 'P')
        camera_id = parse_field(filename, 'C')

        img_path = os.path.join(data_path, filename)
        img_count = sum(1 for img_name in os.listdir(img_path)
                        if img_name.endswith(".png"))
        if img_count == 0:
            raise ValueError("Empty folder: '{}'".format(img_path))

        action_list.append(action_class)

        line_str = "{file_str:s}\t{label:d}\t{frames:d}".format(
            file_str=filename, label=action_class - 1, frames=img_count)
        samples.append((subject_id, camera_id, line_str))

    # constant check for num classes
    num_classes = len(set(action_list))

    if num_classes != len(data_attr["n_names"]):
        logger.warning(
            "Warning: num classes: {:d} is not equal to class names: {}.".
            format(num_classes, len(data_attr["n_names"])))

    for p_i, protocol_i in enumerate(protocols):

        protocol_param = dataset_param["eval_protocols"][protocol_i]

        if protocol_i == "cross_view":
            train_cams = protocol_param["train_cam"]
            is_train = [cam in train_cams for _, cam, _ in samples]
        elif protocol_i == "cross_subjects":
            train_subs = protocol_param["train_subs"]
            is_train = [sub in train_subs for sub, _, _ in samples]
        else:
            raise ValueError("Unknown NTU RGB+D protocol: '{}'".format(protocol_i))

        train_list = [line for (_, _, line), train in zip(samples, is_train) if train]
        test_list = [line for (_, _, line), train in zip(samples, is_train) if not train]

        # protocol name
        protocol_item = "{:d}_{:s}".format(p_i + 1, protocol_i)

        logger.info(protocol_item + ": ")

        # Write train list to file
        train_list_file = protocol_param["file_list_train"].replace(
            "<$protocol_item>", protocol_item)
        with open(train_list_file, "w") as trlf:
            trlf.writelines(line + "\n" for line in train_list)
        logger.info("    Train filelist has been stored in '{:s}'".format(
            train_list_file))

        # Write test list to file
        test_list_file = protocol_param["file_list_test"].replace(
            "<$protocol_item>", protocol_item)
        with open(test_list_file, "w") as telf:
            telf.writelines(line + "\n" for line in test_list)
        logger.info("    Test filelist has been stored in '{:s}'".format(
            test_list_file))

        logger.info(
            "    => Summary: {:d} samples for training and {:d} samples for test."
            .format(len(train_list), len(test_list)))
        logger.info("    => Number of classes: {:d}".format(num_classes))

        logger.info(separator_line(dis_len="half"))

        assert len(train_list) > 0 and len(
            test_list) > 0, "Target dataset has no samples to read."
        data_param_dict["eval_Protocol"][protocol_item] = {
            "train": train_list_file,
            "test": test_list_file,
        }

    # num_classes to eval_Protocol
    data_param_dict["num_classes"] = num_classes

    # write protocol param to json file
    data_param_dict_file = dataset_param["eval_config_file"].format(
        file_list_dir, data_Name)

    with open(data_param_dict_file, 'w') as jsf:
        json.dump(data_param_dict, jsf, indent=4)
    logger.info("Evaluation protocols have been stored in '{:s}'".format(
        data_param_dict_file))
    logger.info(separator_line())

    return data_param_dict_file
示例#8
0
def send_mail_notification(args, mail_config_file=None):
    """
    send_mail_notification: e-mail a summary of the finished run.

    The message carries an HTML body built from the module-level
    ``mail_templete`` (summary line, last line of the first attachment,
    elapsed time, configured online links) plus every configured attachment
    that exists on disk. Missing attachments are reported inline in the mail
    body instead of being attached.

    Sending is best-effort: any failure while building or sending the message
    is printed and otherwise ignored, so a broken mail setup cannot abort the
    caller.

    :param args: run options; reads ``log_name``, ``code``, ``net_arch``,
        ``dataset`` and ``running_time`` (expected to be a
        ``datetime.timedelta`` -- TODO confirm with caller).
    :param mail_config_file: path of the JSON mail configuration; defaults to
        ``mail_config.json`` next to this source file.
    :return: None
    """
    from email.mime.application import MIMEApplication

    if mail_config_file is None:
        mail_config_file = os.path.dirname(
            os.path.realpath(__file__)) + "/mail_config.json"

    # Only parsing needs the config file; do not keep it open for the
    # whole (network-bound) rest of the function.
    with open(mail_config_file, "r") as mcf:
        mail_param = json.load(mcf)

    attachments_list = [
        attachment.format(log_name=args.log_name)
        for attachment in mail_param["attachments"]
    ]

    mail_subject = "[{code:s}] Result of {net_arch:s} on {data_name:s}".format(
        code=args.code, net_arch=args.net_arch, data_name=args.dataset)
    summary_lines = ("Summary of running <b>{code:s}</b> "
                     "with {net_arch:s} on {data_name:s}:").format(
                         code=args.code,
                         net_arch=args.net_arch,
                         data_name=args.dataset)

    # total_seconds() includes whole days; timedelta.seconds alone wraps back
    # to 0 every 24 hours and under-reports long runs.
    time_consume_line = "Time elapsed {:.2f} hours.".format(
        args.running_time.total_seconds() / 3600.0)

    from_nickname = mail_param["from_nickname"].format(code=args.code)

    # The last line of the first attachment (the verbose log) goes into the
    # mail text; an empty log yields an empty line instead of an IndexError.
    with open(attachments_list[0], "r") as flog:
        lines = flog.readlines()
    last_line_in_logs = lines[-1] if lines else ""

    # tensor_board links
    log_basename = args.log_name.split("/")[-1]
    links_html = ""
    for text_str, href_str in mail_param["online_links"].items():
        href_str = href_str.format(log_name=log_basename)
        links_html += """
                <p><a href="{href_str:s}">{text_str:s}</a></p>
            """.format(href_str=href_str, text_str=text_str)

    mail_content = mail_templete.format(
        summary_lines=summary_lines,
        last_line_in_logs=last_line_in_logs,
        time_consume_line=time_consume_line,
        links_html=links_html)

    try:
        # create an e-mail header
        message = MIMEMultipart()
        message['From'] = formataddr(
            (Header(from_nickname, "utf-8").encode(),
             mail_param["mail_username"]))
        message['To'] = ",".join(mail_param["target_addresses"])
        message['Subject'] = Header(mail_subject, "utf-8")

        # mail content html
        message.attach(MIMEText(mail_content, "html", "utf-8"))

        # Missing attachments are noted in the mail body instead of failing.
        valid_attachments = []
        for att_name in attachments_list:
            file_name = att_name.split("/")[-1]
            if not os.path.isfile(att_name):
                stream_str = "Fail to read '{:s}'".format(file_name)
                message.attach(
                    MIMEText(stream_str + "\r\n", "plain", "utf-8"))
                print("=>" + stream_str)
            else:
                valid_attachments.append(att_name)

        if len(valid_attachments) != len(attachments_list):
            print(separator_line())

        for att_name in valid_attachments:
            file_name = att_name.split("/")[-1]
            # MIMEApplication produces exactly one application/octet-stream
            # Content-Type with base64 transfer encoding; assigning the header
            # on a MIMEText part appended a second, conflicting Content-Type.
            with open(att_name, "rb") as att_file:
                attachment = MIMEApplication(att_file.read())
            # add_header quotes / RFC 2231-encodes the filename as needed.
            attachment.add_header("Content-Disposition", "attachment",
                                  filename=file_name)
            message.attach(attachment)

        # The context manager sends QUIT and closes the socket even when
        # login or sendmail raises.
        with smtplib.SMTP_SSL(mail_param["ssl_server"],
                              mail_param["ssl_port"]) as server:
            server.login(mail_param["mail_username"],
                         mail_param["mail_password"])
            server.sendmail(mail_param["mail_username"],
                            mail_param["target_addresses"],
                            message.as_string())

        print("Email notification has been send successfully.")

    except Exception as err:
        # Deliberately best-effort, but surface the cause of the failure.
        print("Error: Fail to send Email.")
        print("    {:s}: {}".format(type(err).__name__, err))

    print(separator_line())
示例#9
0
文件: main.py 项目: ZQSIAT/todo0318
    # ------- send notification  ------------
    send_mail_notification(args)

    if args.tb_writer is not None:
        args.tb_writer.close()
        # --create archive symbolic link
        tensorboard_root = config.get_path_param(
            args.dataset)["tensorboard_root"]
        tb_symbolic_link(args.log_name, tensorboard_root)


if __name__ == '__main__':
    # NOTE(review): leftover debug stub -- the unconditional exit() below
    # terminates the script here, so every statement after it in this block
    # is unreachable. Remove these two lines to actually run the script.
    print('hollow word!')
    exit()

    stream_pool = separator_line()

    # acquire argparse options
    parser = argparse.ArgumentParser()

    # Path of the options file for this run.
    # NOTE(review): the help text says "default: None", but the actual
    # default is the JSON path below.
    parser.add_argument('--option',
                        default="./options/com_cs_G3_base.json",
                        metavar='PATH',
                        type=str,
                        help="args option file path (default: None)")

    # Path of a file to resume from; presumably the checkpoint that
    # test_model() later reads with torch.load -- confirm against main().
    parser.add_argument('--resume',
                        default=None,
                        metavar='PATH',
                        type=str,
                        help="resume file path (default: None)")
示例#10
0
def constructing_model(args, config, logger):
    """
    constructing_model: build the model, metrics, criterion and optimizer.

    The network class comes from ``model_param["net_arch"][args.net_arch]``,
    a "<module>.<class>" string resolved under ``schemes.net_arch``. The
    criterion comes from ``model_param["criterion"][args.criterion]``, an
    expression evaluated (and called with no arguments) in this module's
    namespace. Pre-trained weights, when given, are copied into every layer
    whose name matches, except layers whose name contains "fc".

    :param args: run options (net_arch, metrics, adjust_lr, criterion,
        optimizer, dataset, data_shape, evaluate, pretrained, plot_net_arch,
        cuda, gpu_card).
    :param config: project config providing ``get_model_param()`` and
        ``get_number_classes(dataset)``.
    :param logger: logger used for the model summary.
    :return: ``[model, metrics, criterion, optimizer]``; ``optimizer`` is
        ``None`` when ``args.evaluate`` is set.
    :raises AssertionError: on unknown or unset option names.
    :raises NotImplementedError: for an optimizer that is neither SGD nor Adam.
    """
    import importlib

    model_param = config.get_model_param()
    assert args.net_arch in model_param["net_arch"], "Unknown net_arch!"
    for metric in args.metrics:
        assert metric in model_param["metrics"], "Unknown metrics!"
    if args.adjust_lr is not None:
        assert args.adjust_lr in model_param["adjust_lr"], "Unknown adjust lr!"
    assert args.criterion in model_param["criterion"], "Unknown criterion!"
    assert args.optimizer in model_param["optimizer"], "Unknown optimizer!"

    # read the number of classes from previous generated protocol file
    num_classes = config.get_number_classes(args.dataset)

    # -------------------- net_arch --------------------------
    assert model_param["net_arch"][
        args.net_arch] is not None, "Invalid parameter!"

    # Equivalent of "from schemes.net_arch.<module> import <class>", done
    # without exec(): a name bound by exec() inside a function is visible to
    # a later eval() only through a CPython implementation detail, which no
    # longer holds with the Python 3.13 locals() semantics (PEP 667).
    net_arch_str = model_param["net_arch"][args.net_arch].split(".")
    arch_module = importlib.import_module("schemes.net_arch." +
                                          net_arch_str[0])
    net_arch = getattr(arch_module, net_arch_str[1])

    # args.data_shape is documented upstream as [CxFxHxW] without batch size.
    # For 5-D shapes the channel/segment counts are read from index 1/2,
    # otherwise from index 0/1 (CxF layout).
    if len(args.data_shape) == 5:
        in_channel = args.data_shape[1]
        num_segments = args.data_shape[2]
    else:
        # FxC layout alternative:
        # in_channel = args.data_shape[1]
        # num_segments = args.data_shape[0]
        in_channel = args.data_shape[0]  # CxF
        num_segments = args.data_shape[1]

    model = net_arch(num_classes=num_classes,
                     in_channel=in_channel,
                     num_segments=num_segments)

    # -------------------- metrics --------------------------
    metrics = args.metrics

    # -------------------- criterion --------------------------
    assert model_param["criterion"][
        args.criterion] is not None, "Invalid parameter!"

    # NOTE: the configured criterion string is eval'ed against this module's
    # namespace; the model config must therefore be trusted input.
    criterion_func_string = model_param["criterion"][args.criterion] + "()"
    criterion = eval(criterion_func_string)

    if not args.evaluate:

        # -------------------- optimizer --------------------------
        opt_param = model_param["optimizer"][args.optimizer]

        if "SGD" in args.optimizer:
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt_param["lr"],
                                  momentum=opt_param["momentum"],
                                  weight_decay=opt_param["weight_decay"])
        elif "Adam" in args.optimizer:
            optimizer = optim.Adam(model.parameters(),
                                   lr=opt_param["lr"],
                                   weight_decay=opt_param["weight_decay"])
        else:
            raise NotImplementedError("Unknown optimizer!")
    else:
        optimizer = None

    if args.pretrained is not None:
        # Load onto the CPU: load_state_dict copies the values into the
        # model's own tensors, and this also works on hosts without the GPU
        # the weights were saved from.
        pretrained_model = torch.load(args.pretrained, map_location="cpu")
        pretrained_dict = pretrained_model["state_dict"]
        model_dict = model.state_dict()

        # Strip DataParallel's "module." prefix and skip "fc" layers so the
        # classifier head keeps its fresh initialisation.
        pretrained_dict_new = {}
        for k, v in pretrained_dict.items():
            kn = k.replace("module.", "")
            if kn in model_dict and "fc" not in kn:
                pretrained_dict_new[kn] = v

        model_dict.update(pretrained_dict_new)
        model.load_state_dict(model_dict)

        logger.info("=> using pre-trained model '{}' from '{:s}'".format(
            args.net_arch, args.pretrained))
        # # freezing layers except the fc
        # for name, param in model.named_parameters():
        #     if "fc" in name:
        #         continue
        #     if "conv1" in name:
        #         continue
        #     param.requires_grad = False
        #     logger.info("--> freezing '{:s}'".format(name))

    # ------------------ plot and save model arch
    if args.plot_net_arch:
        input_size = args.data_shape
        # file_name = net_arch_file_name.format(args.log_name)
        # plot_net_architecture(model, input_size, file_name)

        dummy_input = torch.randn(tuple(input_size))
        dummy_input.unsqueeze_(0)
        logger.info("dummy input shape: {:s}".format(str(dummy_input.shape)))

        # if args.tensorboard:
        #     args.tb_writer.add_graph(model, (dummy_input,))

    # ----------------------- cuda transfer -----------------------
    if args.cuda:
        if len(args.gpu_card.split(",")) > 1:
            logger.info("=> Using mult GPU")
            model = torch.nn.DataParallel(model).cuda()
        else:
            logger.info("=> Using single GPU")
            model = model.cuda()

        criterion = criterion.cuda()
        cudnn.benchmark = True
        # cudnn.enabled = True
        # cudnn.deterministic = True

    # -----------------------------------------------------------------
    logger.info("model summary: ")

    logger.info(separator_line(dis_len="half"))
    logger.info("net_arch is: ")
    logger.info(model)

    logger.info(separator_line(dis_len="half"))
    net_parameter_statistics(logger, model.named_parameters())

    logger.info(separator_line(dis_len="half"))
    logger.info("metrics is: ")
    logger.info(metrics)

    logger.info(separator_line(dis_len="half"))
    logger.info("criterion is: ")
    logger.info(criterion)

    logger.info(separator_line(dis_len="half"))
    logger.info("optimizer is: ")
    logger.info(optimizer)
    logger.info(separator_line())

    return [model, metrics, criterion, optimizer]
示例#11
0
def test_model(args, config, logger, val_loader, model_combined):
    """
    test_model: evaluate a checkpointed model on ``val_loader``.

    Restores ``state_dict`` from the checkpoint at ``args.resume``, runs one
    pass over ``val_loader`` without gradients while logging running loss
    and accuracy, saves softmax scores with true/predicted labels under the
    dataset's ``evaluate_root`` and, if ``args.plot_confusion_matrix`` is
    set, adds a confusion-matrix figure to ``args.tb_writer``.

    :param args: run options (resume, net_arch, dataset, cuda, log_interval,
        batch_size, eval_protocol, code, plot_confusion_matrix, tb_writer).
    :param config: project config providing ``get_path_param(dataset)``.
    :param logger: logger for progress and summary lines.
    :param val_loader: iterable of ``(input, target, _)`` batches.
    :param model_combined: ``[model, metrics, criterion, optimizer]`` as
        returned by ``constructing_model``; the optimizer is unused.
    :return: ``(accuracy.avg, losses.avg)``; accuracy is only accumulated
        when "accuracy_percent" is among the metrics.
    :raises RuntimeError: if ``args.resume`` is not an existing file.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()
    samples_count = AverageMeter()
    samples_right = AverageMeter()

    [model, metrics, criterion, _] = model_combined

    logger.info("=> Recovery model dict state from checkpoint")
    if os.path.isfile(args.resume):
        logger.info("{:s} loading checkpoint '{}'".format(
            datetime_now_string(), args.resume))
        # Load onto the CPU; load_state_dict copies into the model's tensors
        # on whatever device the model lives on.
        checkpoint = torch.load(args.resume, map_location="cpu")

        # load param from checkpoint file
        best_prec = checkpoint['best_prec']
        arch = checkpoint['arch']
        state_dict = checkpoint['state_dict']
        assert arch == args.net_arch, "The arch of checkpoint is not consistent with current model!"

        model.load_state_dict(state_dict)

        logger.info("{:s} loaded checkpoint '{}' (epoch {})".format(
            datetime_now_string(), args.resume, checkpoint['epoch']))

    else:
        logger.error("Error: No checkpoint found at '{}'".format(args.resume))
        raise RuntimeError(
            "No checkpoint found at '{}'".format(args.resume))

    # get checkpoint saved directory
    evaluate_root = config.get_path_param(args.dataset)["evaluate_root"]
    ensure_directory(evaluate_root)

    logger.info(separator_line())
    # switch to evaluate mode
    model.eval()
    logger.info("{:s} Start model evaluation".format(datetime_now_string()))
    logger.info(separator_line())

    with torch.no_grad():
        end = time.time()
        y_true = []
        y_pred = []
        output_list = []

        for i, (inputs, target, _) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            # Same switch constructing_model() uses to put the model on the
            # GPU; "args.cuda is not None" was also true for args.cuda=False.
            if args.cuda:
                inputs = inputs.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            output_list.append(output)

            y_true.extend(target.tolist())
            # top-1 class index per sample
            _, pred = output.topk(1, 1, True, True)
            y_pred.extend(pred.t().tolist()[0])

            # measure accuracy and record loss
            if "accuracy_percent" in metrics:
                predicted_accuracy, n_correct_elems = calculate_accuracy_percent(
                    output, target)
                samples_count.update(inputs.size(0))
                samples_right.update(n_correct_elems.item())
                accuracy.update(predicted_accuracy.item(), inputs.size(0))

            # TODO: add more metrics such as recall for special dataset

            losses.update(loss.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_interval == 0:
                logger.info(
                    '{:4.1f}% evaluate:[{:05d}/{:05d}] '
                    'Loss: {loss.val:7.4f} ({loss.avg:7.4f})  '
                    'Accuracy: {acc.val:6.2f} ({acc.avg:6.2f})  '.format(
                        100.0 * i / len(val_loader),
                        i * args.batch_size,
                        len(val_loader.dataset),
                        loss=losses,
                        acc=accuracy))

        # Class scores per sample; dim=1 is what the deprecated implicit-dim
        # rule chose for this 2-D (samples x classes) tensor.
        output_score = F.softmax(torch.cat(output_list, dim=0), dim=1)
        output_file_path = "{:s}/{:s}_{:.2f}_{:s}_{:s}".format(
            evaluate_root, args.eval_protocol, accuracy.avg, args.code,
            args.net_arch)
        save_predict_result(
            logger, {
                'output_score': output_score,
                'true_label': y_true,
                'pred_label': y_pred
            }, output_file_path)

        if args.plot_confusion_matrix:
            class_names = val_loader.class_names
            plt_fig = generate_confusion_matrix(y_true, y_pred, class_names)
            args.tb_writer.add_figure('confusion matrix', plt_fig)

        logger.info('=> Evaluate:  '
                    'Elapse: {data_time.sum:.2f}/{sum_time.sum:.2f}s  '
                    'Loss: {loss.avg:.4f}  '
                    'Model record accuracy: {best_acc:.2f}% '
                    'Evaluate accuracy: {acc.avg:.2f}% '
                    '[{right:.0f}/{count:.0f}]'.format(
                        loss=losses,
                        data_time=data_time,
                        sum_time=batch_time,
                        best_acc=best_prec,
                        acc=accuracy,
                        right=samples_right.sum,
                        count=samples_count.sum))
        logger.info(separator_line())

    return accuracy.avg, losses.avg