Example #1
def load_checkpoint(name):
    check = torch.load("../models/" + name, map_location="cuda")
    loss = check["val_best_loss"]

    # model = se_resnet34(num_classes=2, multi_output=True).to(device)
    model = densenet121(if_selayer=True).to(device)

    model.load_state_dict(check["model"])
    model = model.to(device)

    model.eval()

    return model
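For context, a minimal sketch of a save routine that would produce the checkpoint format this loader expects; the key names ("model", "val_best_loss") come from the example, everything else is an assumption:

import torch

def save_checkpoint(model, val_best_loss, name):
    # keys mirror what load_checkpoint() reads back
    torch.save({"model": model.state_dict(), "val_best_loss": val_best_loss},
               "../models/" + name)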
Example #2
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(
        img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(
        json_path)

    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = densenet121(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-3.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(
        class_indict[str(predict_cla)], predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()
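The script above expects a class_indices.json that maps index strings to class names. A minimal sketch of producing one from a torchvision ImageFolder (the dataset path is an assumption):

import json
from torchvision import datasets

# class_to_idx maps name -> index; invert it so a prediction index i
# can be looked up as class_indict[str(i)], as in the example above
dataset = datasets.ImageFolder("../flower_data/train")
class_indices = {str(v): k for k, v in dataset.class_to_idx.items()}
with open("class_indices.json", "w") as f:
    json.dump(class_indices, f, indent=4)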
Example #3
def main():
    logger, result_dir, _ = utils.config_backup_get_log(args, __file__)

    device = utils.get_device()
    utils.set_seed(args.seed, device)  # set random seed

    dataset = COVID19DataSet(root=args.datapath,
                             ctonly=args.ctonly)  # load dataset
    trainset, testset = split_dataset(dataset=dataset, logger=logger)

    if args.model.lower() in ['mobilenet']:
        net = mobilenet_v2(task='classification',
                           moco=False,
                           ctonly=args.ctonly).to(device)
    elif args.model.lower() in ['densenet']:
        net = densenet121(task='classification',
                          moco=False,
                          ctonly=args.ctonly).to(device)
    else:
        raise ValueError('unknown model: %s' % args.model)

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25,
                                                gamma=0.1)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.bstrain,
                                              shuffle=True,
                                              num_workers=args.nworkers)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.bstest,
                                             shuffle=False,
                                             num_workers=args.nworkers)

    best_auroc = 0.
    best_aupr = 0.
    best_epoch = 0
    print('==> Start training ..')
    start = time.time()
    for epoch in range(args.maxepoch):
        net = train(epoch, net, trainloader, criterion, optimizer, scheduler,
                    args.model, device)
        scheduler.step()
        if epoch % 5 == 0:
            auroc, aupr, f1_score, accuracy = validate(net, testloader, device)
            logger.write(
                'Epoch:%3d | AUROC: %5.4f | AUPR: %5.4f | F1_Score: %5.4f | Accuracy: %5.4f\n'
                % (epoch, auroc, aupr, f1_score, accuracy))
            if auroc > best_auroc:
                best_auroc = auroc
                best_aupr = aupr
                best_epoch = epoch
                print("save checkpoint...")
                torch.save(net.state_dict(),
                           './%s/%s.pth' % (result_dir, args.model))

    auroc, aupr, f1_score, accuracy = validate(net, testloader, device)
    logger.write(
        'Epoch:%3d | AUROC: %5.4f | AUPR: %5.4f | F1_Score: %5.4f | Accuracy: %5.4f\n'
        % (epoch, auroc, aupr, f1_score, accuracy))

    if args.batchout:
        with open('temp_result.txt', 'w') as f:
            f.write("%10.8f\n" % (best_auroc))
            f.write("%10.8f\n" % (best_aupr))
            f.write("%d" % (best_epoch))

    end = time.time()
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    logger.write("Elapsed Time: {:0>2}:{:0>2}:{:05.2f}\n".format(
        int(hours), int(minutes), seconds))
    return True
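validate() is defined elsewhere in the project; a minimal sketch consistent with the four metrics it returns, assuming binary labels and a single-logit head (the scikit-learn calls are an assumption, not the author's code):

import torch
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

def validate(net, loader, device):
    net.eval()
    ys, ps = [], []
    with torch.no_grad():
        for x, y in loader:
            logits = net(x.to(device)).squeeze(1)
            ps.extend(torch.sigmoid(logits).cpu().tolist())
            ys.extend(y.tolist())
    preds = [int(p > 0.5) for p in ps]
    accuracy = sum(int(p == t) for p, t in zip(preds, ys)) / len(ys)
    return (roc_auc_score(ys, ps), average_precision_score(ys, ps),
            f1_score(ys, preds), accuracy)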
Example #4
if args.model == "resnet34":  # opening branch inferred; the excerpt starts mid-chain
    # model = se_resnet34(num_classes=2).to(device)
    model = se_resnet34(num_classes=2, multi_output=True).to(device)
elif args.model == "bengali_resnet34":
    model = model_bengali.se_resnet34(num_classes=2,
                                      multi_output=True).to(device)
elif args.model == "bengali_resnext50":
    model = model_bengali.se_resnext50_32x4d(num_classes=2,
                                             multi_output=True).to(device)
elif args.model == "resnet152":
    model = se_resnet152(num_classes=2, multi_output=True).to(device)
elif args.model == "resnext50":
    model = se_resnext50_32x4d(num_classes=2, multi_output=True).to(device)
elif args.model == "resnext101":
    model = se_resnext101_32x8d(num_classes=2, multi_output=True).to(device)
elif args.model == "densenet":
    model = densenet121(if_selayer=True).to(device)
elif args.model == "inception_v3":
    model = torchvision.models.inception_v3(pretrained=False,
                                            num_classes=11 + 168 +
                                            7).to(device)
elif args.model in efficientnets:
    from efficientnet_pytorch import EfficientNet
    efficientnet_name = args.model
    print("efficientnet:", efficientnet_name)

    model = EfficientNet.from_pretrained(args.model,
                                         num_classes=11 + 168 + 7).to(device)
else:
    raise ValueError("unknown model: %s" % args.model)

# train_all = load_train_df()
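The elif chain above (and the longer ones in the TensorFlow examples below) can be collapsed into a constructor registry; a sketch reusing a few names from this example (the structure is a suggestion, not the author's code):

MODEL_BUILDERS = {
    "resnet152": lambda: se_resnet152(num_classes=2, multi_output=True),
    "resnext50": lambda: se_resnext50_32x4d(num_classes=2, multi_output=True),
    "densenet": lambda: densenet121(if_selayer=True),
}

def build_model(name, device):
    # raises a descriptive error instead of a bare ValueError()
    try:
        return MODEL_BUILDERS[name]().to(device)
    except KeyError:
        raise ValueError("unknown model: %s" % name)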
Example #5
def train(args):
    batch_size = args.batch_size
    epoch = args.epoch
    network = args.network
    opt = args.opt
    train = unpickle(args.train_path)
    test = unpickle(args.test_path)
    train_data = train[b'data']
    test_data = test[b'data']

    x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    x_train = x_train.transpose(0, 2, 3, 1)
    y_train = train[b'fine_labels']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test[b'fine_labels']

    x_train = norm_images(x_train)
    x_test = norm_images(x_test)

    print('-------------------------------')
    print('--train/test len: ', len(train_data), len(test_data))
    print('--x_train norm: ', compute_mean_var(x_train))
    print('--x_test norm: ', compute_mean_var(x_test))
    print('--batch_size: ', batch_size)
    print('--epoch: ', epoch)
    print('--network: ', network)
    print('--opt: ', opt)
    print('-------------------------------')

    if not os.path.exists('./trans/tran.tfrecords'):
        generate_tfrecord(x_train, y_train, './trans/', 'tran.tfrecords')
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')

    dataset = tf.data.TFRecordDataset('./trans/tran.tfrecords')
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None, ])
    y_input_one_hot = tf.one_hot(y_input, 100)
    lr = tf.placeholder(tf.float32, [])

    if network == 'resnet50':
        prob = resnet50(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet34':
        prob = resnet34(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    elif network == 'resnet18':
        prob = resnet18(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    elif network == 'seresnet50':
        prob = se_resnet50(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet110':
        prob = resnet110(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet110':
        prob = se_resnet110(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet152':
        prob = se_resnet152(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet152':
        prob = resnet152(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet_fixed':
        prob = get_resnet(x_input, 152, trainable=True, w_init=tf.orthogonal_initializer())
    elif network == 'densenet121':
        prob = densenet121(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet169':
        prob = densenet169(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet201':
        prob = densenet201(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet161':
        prob = densenet161(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet100bc':
        prob = densenet100bc(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet190bc':
        prob = densenet190bc(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext50':
        prob = resnext50(x_input, reuse=False, is_training=True, cardinality=32,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext110':
        prob = resnext110(x_input, reuse=False, is_training=True, cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext152':
        prob = resnext152(x_input, reuse=False, is_training=True, cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext50':
        prob = se_resnext50(x_input, reuse=False, is_training=True, cardinality=32,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext110':
        prob = se_resnext110(x_input, reuse=False, is_training=True, cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext152':
        prob = se_resnext152(x_input, reuse=False, is_training=True, cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    else:
        raise ValueError('unknown network: %s' % network)

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prob, labels=y_input_one_hot))

    conv_var = [var for var in tf.trainable_variables() if 'conv' in var.name]
    l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in conv_var])
    loss = l2_loss * 5e-4 + loss

    if opt == 'adam':
        opt = tf.train.AdamOptimizer(lr)
    elif opt == 'momentum':
        opt = tf.train.MomentumOptimizer(lr, 0.9)
    elif opt == 'nesterov':
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = opt.minimize(loss)

    logit_softmax = tf.nn.softmax(prob)
    acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logit_softmax, 1), y_input), tf.float32))

    # -------------------------------Test-----------------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    if network == 'resnet50':
        prob_test = resnet50(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=True)
    elif network == 'densenet121':
        prob_test = densenet121(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input, reuse=True, is_training=False, kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input, reuse=True, is_training=False, kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input, is_training=False, reuse=True, cardinality=32, kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input, is_training=False, reuse=True, cardinality=32, kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input, is_training=False, reuse=True, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input, reuse=True, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input, reuse=True, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input, reuse=True, is_training=False, cardinality=32, kernel_initializer=None)

    logit_softmax_test = tf.nn.softmax(prob_test)
    acc_test = tf.reduce_sum(tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input), tf.float32))
    # ----------------------------------------------------------------------------
    saver = tf.train.Saver(max_to_keep=1, var_list=tf.global_variables())
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    now_lr = 0.001  # Warm Up
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        counter = 0
        max_test_acc = -1
        for i in range(epoch):
            sess.run(iterator.initializer)
            while True:
                try:
                    batch_train, label_train = sess.run(next_element)
                    _, loss_val, acc_val, lr_val = sess.run([train_op, loss, acc, lr],
                                                            feed_dict={x_input: batch_train, y_input: label_train,
                                                                       lr: now_lr})

                    counter += 1

                    if counter % 100 == 0:
                        print('counter: ', counter, 'loss_val', loss_val, 'acc: ', acc_val)
                    if counter % 1000 == 0:
                        print('start test ')
                        sess.run(iterator_test.initializer)
                        avg_acc = []
                        while True:
                            try:
                                batch_test, label_test = sess.run(next_element_test)
                                acc_test_val = sess.run(acc_test, feed_dict={x_input: batch_test, y_input: label_test})
                                avg_acc.append(acc_test_val)
                            except tf.errors.OutOfRangeError:
                                print('end test ', np.sum(avg_acc) / len(y_test))
                                now_test_acc = np.sum(avg_acc) / len(y_test)
                                if now_test_acc > max_test_acc:
                                    print('***** Max test changed: ', now_test_acc)
                                    max_test_acc = now_test_acc
                                    filename = 'params/distinct/' + network + '_{}.ckpt'.format(counter)
                                    # saver.save(sess, filename)
                                break
                except tf.errors.OutOfRangeError:
                    print('end epoch %d/%d , lr: %f' % (i, epoch, lr_val))
                    now_lr = lr_schedule(i, args.epoch)
                    break
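lr_schedule() is not shown in the excerpt; a plausible step-decay sketch consistent with the 0.001 warm-up value above (the breakpoints and rates are assumptions):

def lr_schedule(epoch, total_epochs):
    # step decay: drop the learning rate at 50% and 75% of training
    if epoch < total_epochs * 0.5:
        return 0.01
    if epoch < total_epochs * 0.75:
        return 0.001
    return 0.0001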
Example #6
def test(args):
    # train = unpickle('/data/ChuyuanXiong/up/cifar-100-python/train')
    # train_data = train[b'data']
    # x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    # x_train = x_train.transpose(0, 2, 3, 1)

    test = unpickle(args.test_path)
    test_data = test[b'data']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test[b'fine_labels']

    x_test = norm_images(x_test)
    # x_test = norm_images_using_mean_var(x_test, *compute_mean_var(x_train))

    network = args.network
    ckpt = args.ckpt

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None, ])
    # -------------------------------Test-----------------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    if network == 'resnet50':
        prob_test = resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=False)
    elif network == 'densenet121':
        prob_test = densenet121(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input, is_training=False, reuse=False, cardinality=32, kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input, is_training=False, reuse=False, cardinality=32, kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input, is_training=False, reuse=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input, reuse=False, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input, reuse=False, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input, reuse=False, is_training=False, cardinality=32, kernel_initializer=None)
    else:
        raise ValueError('unknown network: %s' % network)

    # prob_test = tf.layers.dense(prob_test, 100, reuse=True, name='before_softmax')
    logit_softmax_test = tf.nn.softmax(prob_test)
    acc_test = tf.reduce_sum(tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input), tf.float32))

    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars

    saver = tf.train.Saver(var_list=var_list)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        saver.restore(sess, ckpt)
        sess.run(iterator_test.initializer)
        avg_acc = []
        while True:
            try:
                batch_test, label_test = sess.run(next_element_test)
                acc_test_val = sess.run(acc_test, feed_dict={x_input: batch_test, y_input: label_test})
                avg_acc.append(acc_test_val)
            except tf.errors.OutOfRangeError:
                print('end test ', np.sum(avg_acc) / len(y_test))
                break
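generate_tfrecord() and the parse functions are external to both excerpts; a minimal TF1-style writer consistent with the call sites (the feature names are assumptions and must match whatever parse_function/parse_test decode):

import os
import tensorflow as tf

def generate_tfrecord(images, labels, out_dir, filename):
    # serialize each (image, label) pair into one tf.train.Example
    with tf.python_io.TFRecordWriter(os.path.join(out_dir, filename)) as writer:
        for img, label in zip(images, labels):
            feature = {
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[img.astype("float32").tobytes()])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[int(label)])),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())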
Example #7
def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print(
        'Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/'
    )
    tb_writer = SummaryWriter()
    if not os.path.exists("./weights"):
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(
        args.data_path)

    data_transform = {
        "train":
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val":
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0,
              8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=nw,
        collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # load pretrained weights if provided
    model = densenet121(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            load_state_dict(model, args.weights)
        else:
            raise FileNotFoundError("weights file not found: {}".format(
                args.weights))

    # optionally freeze layers
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze all weights except the final fully connected layer
            if "classifier" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg,
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1E-4,
                          nesterov=True)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (
        1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model, data_loader=val_loader, device=device)

        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
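The lr_lambda above traces half a cosine from 1.0 down to lrf, so the effective rate decays smoothly from lr to lr * lrf. A quick check with assumed values epochs=10, lrf=0.1:

import math

epochs, lrf = 10, 0.1  # assumed values for illustration
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf
for e in (0, 5, 10):
    print(e, round(lf(e), 3))  # -> 1.0, 0.55, 0.1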
Example #8
# model selection
if args.model == "resnet18":
    print("SeNet18")
    model = se_resnet18(2,
                        if_mixup=args.mixup,
                        if_shake_shake=args.shake,
                        first_conv_stride=args.first_stride,
                        first_pool=True).to(device)
elif args.model == "resnet152":
    print("SeNet152")
    model = se_resnet152(2, if_mixup=args.mixup,
                         if_shake_shake=args.shake).to(device)
elif args.model == "densenet121":
    print("DenseNet121")
    model = densenet121(if_mixup=args.mixup,
                        if_selayer=args.se,
                        first_conv_stride=args.first_stride,
                        first_pool=True).to(device)
elif args.model == "densenet169":
    print("DenseNet169")
    model = densenet169(if_mixup=args.mixup, if_selayer=args.se).to(device)
elif args.model == "densenet201":
    print("DenseNet201")
    model = densenet201(if_mixup=args.mixup, if_selayer=args.se).to(device)
elif args.model == "dpn92":
    print("DPN92")
    model = dpn92(num_classes=2, if_selayer=args.se).to(device)
elif args.model == "dpn98":
    print("DPN98")
    model = dpn98(num_classes=2, if_selayer=args.se).to(device)
elif args.model == "dpn131":
    print("DPN131")
Example #9
def train(X_train,
          Y_train,
          X_val,
          Y_val,
          unlabeled_data=None,
          batch=16,
          gpu_id=0,
          epoch=100):
    """
    train model from training and validation data.
    Args :
        X_train : Training data
        Y_train : Labels of training data
        X_val : Validation data
        Y_val : Labels of validation data
        unlabeled_data : Training data which don't have labels (This is for semisupervised learning.)
        batch : batchsize
        gpu_id : GPU id where model is trained at
    """

    assert len(X_train) == len(
        Y_train
    ), "training data and its labels must have the same length : {}, {}".format(
        len(X_train), len(Y_train))
    assert len(X_val) == len(
        Y_val
    ), "validation data and its labels must have the same length : {}, {}".format(
        len(X_val), len(Y_val))
    if unlabeled_data is not None:
        assert X_train.shape[1:] == X_val.shape[1:] == unlabeled_data.shape[
            1:], "All data must have the same shape"

    H, W = X_train[0].shape[:2]
    print("img shape: {}*{}".format(H, W))

    BATCH_SIZE = batch
    EPOCH = epoch
    gpu_name = "cuda:" + str(gpu_id)
    device = torch.device(gpu_name if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.benchmark = True

    print("=============training config=============")
    print("device:{}".format(device))
    print("MODEL:", args.model)
    print("save dir : ", model_save_dir)
    print("batch size:{}".format(BATCH_SIZE))
    if args.mixup:
        print("Manifold-mixup:{}".format(args.mixup))
    if args.shake:
        print("Shake-shake regularization:{}".format(args.shake))
    if args.vat:
        print("VAT regularization:{}".format(args.vat))
    print("========================================")

    # Calculate the average of pixels per channel
    # def avg_each_channel(path):
    #     img = Image.open(path)
    #     if not len(np.array(img)) == 3:
    #         img = img.convert("RGB")
    #     img = np.asarray(img)
    #     ch_mean = np.average(np.average(img, axis=0), axis=0)
    #     return ch_mean # (3, )
    #
    # ch_means_per_image = Parallel(n_jobs=-1, verbose=-1)([delayed(avg_each_channel)(path) for path in path_dev])
    # ch_mean = np.average(ch_means_per_image, axis=0)
    # print("channel mean:{}".format(ch_mean))
    #
    # def std_each_channel(path):
    #     channel_mean = [237.78516378, 237.16343756, 237.0501237]
    #
    #     img = Image.open(path)
    #     if not len(np.array(img)) == 3:
    #         img = img.convert("RGB")
    #     img = np.asarray(img)
    #     R, G, B = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    #     R_flat = R.flatten()
    #     G_flat = G.flatten()
    #     B_flat = B.flatten()
    #
    #     R_diff = np.sum(np.square(R_flat - channel_mean[0]))
    #     G_diff = np.sum(np.square(G_flat - channel_mean[1]))
    #     B_diff = np.sum(np.square(B_flat - channel_mean[2]))
    #
    #     return (R_diff, G_diff, B_diff)
    # pixels_diff = Parallel(n_jobs=-1, verbose=-1)([delayed(std_each_channel)(path) for path in path_dev])
    # R_all, G_all, B_all = 0, 0, 0
    # for pixel_diff in pixels_diff:
    #     R_all += pixel_diff[0]
    #     G_all += pixel_diff[1]
    #     B_all += pixel_diff[2]
    #
    # R_std = np.sqrt(R_all/(H*W*len(path_dev)))
    # G_std = np.sqrt(G_all/(H*W*len(path_dev)))
    # B_std = np.sqrt(B_all/(H*W*len(path_dev)))
    #
    # ch_std = [R_std, G_std, B_std]
    # print("channel std:", ch_std)

    ch_mean = [237.78516378, 237.16343756, 237.0501237]
    ch_std = [146.47616225347468, 146.6169214951974, 145.59586636818233]

    c_train = Counter(Y_train)
    c_val = Counter(Y_val)
    c_test = Counter(Y_test)
    print("train:{}, val:{}, test:{}".format(c_train, c_val, c_test))
    print(
        "train data length:{}, validation data length:{}, test data length:{}".
        format(len(X_train), len(X_val), len(X_test)))

    if args.semisupervised:
        train = MeanTeacherDataset(
            X_train,
            Y_train,
            transform=transforms.Compose([
                #ColorJitter(brightness=0.0, contrast=0.4, hue=0.0),
                RandomRotate(hard_rotate=False, angle=5),
                Regularizer(ch_mean, ch_std=ch_std),
                #Samplewise_Regularizer(),
                ToTensor()
            ]))

    else:
        train = Dataset(
            X_train,
            Y_train,
            transform=transforms.Compose([  #UnsharpMasker(radius=5.0),
                #ColorJitter(brightness=0.0, contrast=0.4, hue=0.0),
                #RandomScalor(scale_range=(301, 330), crop_size=H),
                RandomRotate(hard_rotate=False, angle=5),
                Regularizer(ch_mean, ch_std=ch_std),
                #Samplewise_Regularizer(),
                ToTensor()
            ]))
    val = Dataset(
        X_val,
        Y_val,
        transform=transforms.Compose([  #UnsharpMasker(radius=5.0),
            Regularizer(ch_mean, ch_std=ch_std),
            #Samplewise_Regularizer(),
            ToTensor()
        ]))

    dataset_sizes = {"train": train.__len__(), "val": val.__len__()}
    print("dataset size:", dataset_sizes)

    val_batch = BATCH_SIZE

    if args.pseudo_label:
        # Class ratio may be unbalanced
        class_sample_count = np.array(
            [len(np.where(Y_train == t)[0]) for t in np.unique(Y_train)])
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in Y_train])
        samples_weight = torch.from_numpy(samples_weight)
        samples_weight = samples_weight.double()
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            samples_weight, len(samples_weight))
        train_iter = torch.utils.data.DataLoader(train,
                                                 batch_size=BATCH_SIZE,
                                                 shuffle=False,
                                                 sampler=sampler)
    else:
        train_iter = torch.utils.data.DataLoader(train,
                                                 batch_size=BATCH_SIZE,
                                                 shuffle=True)
    val_iter = torch.utils.data.DataLoader(val,
                                           batch_size=val_batch,
                                           shuffle=False)

    # model selection
    if args.model == "resnet18":
        print("SeNet18")
        model = se_resnet18(2,
                            if_mixup=args.mixup,
                            if_shake_shake=args.shake,
                            first_conv_stride=2,
                            first_pool=True).to(device)
    elif args.model == "resnet50":
        print("SeNet50")
        model = se_resnet50(2,
                            if_mixup=args.mixup,
                            if_shake_shake=args.shake,
                            first_conv_stride=2,
                            first_pool=True).to(device)
    elif args.model == "resnet101":
        print("SeNet101")
        model = se_resnet101(2,
                             if_mixup=args.mixup,
                             if_shake_shake=args.shake,
                             first_conv_stride=2,
                             first_pool=True).to(device)
    elif args.model == "resnet152":
        print("SeNet152")
        model = se_resnet152(2, if_mixup=args.mixup,
                             if_shake_shake=args.shake).to(device)
    elif args.model == "densenet121":
        print("DenseNet121")
        model = densenet121(if_mixup=args.mixup,
                            if_selayer=args.se,
                            first_conv_stride=2,
                            first_pool=True,
                            drop_rate=args.drop_rate).to(device)

    elif args.model == "dpn92":
        print("DPN92")
        model = dpn92(num_classes=2,
                      if_selayer=args.se,
                      if_mixup=args.mixup,
                      first_conv_stride=2,
                      first_pool=True).to(device)
    elif args.model == "dpn98":
        print("DPN98")
        model = dpn98(num_classes=2, if_selayer=args.se,
                      if_mixup=args.mixup).to(device)

    else:
        print("WRONG MODEL NAME")
        input("---------Stop-----------")

    if args.semisupervised:
        """ Declare the teacher model """
        def sigmoid_rampup(current, rampup_length):
            """ Exponential rampup from https://arxiv.org/abs/1610.02242 """
            if rampup_length == 0:
                return 1.0
            else:
                current = np.clip(current, 0.0, rampup_length)
                phase = 1.0 - current / rampup_length
                return float(np.exp(-5.0 * phase * phase))

        def get_current_consistency_weight(epoch):
            # Consistency ramp-up from https://arxiv.org/abs/1610.02242
            return args.consistency * sigmoid_rampup(
                epoch, rampup_length=int(args.epoch / 2))

        def update_teacher(student, teacher, alpha, global_step):
            """
            update parameters of the teacher model.
            Args :
                student : Current model to train
                teacher : Current teacher model
                alpha : A parameter of models mixing weights
                global_step : Global step of training
            """
            alpha = min(1 - 1 / (global_step + 1), alpha)
            for teacher_param, param in zip(teacher.parameters(),
                                            student.parameters()):
                teacher_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

        teacher_model = copy.deepcopy(model)
        consistency_criterion = nn.MSELoss()

    print("model preparing...")
    loss_function = nn.CrossEntropyLoss(ignore_index=-1)
    bce_loss = torch.nn.BCEWithLogitsLoss()
    softmax = torch.nn.Softmax(dim=1)

    #Set optimizer
    init_learning_rate = args.learning_rate
    optimizer = optim.SGD(model.parameters(),
                          lr=init_learning_rate,
                          momentum=0.9,
                          nesterov=True,
                          dampening=0,
                          weight_decay=0.0005)
    # optimizer = torch.optim.Adam(model.parameters(), lr=init_learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[int(EPOCH * 0.5),
                               int(EPOCH * 0.75)], gamma=0.1)

    lowest_loss = 1000000000000

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    PATH = ""
    global_step = 0
    start_epoch = 0
    Logger = Train_Logger()

    if args.resume:
        # resume training from the latest checkpoint

        checkpoint_names = os.listdir(os.path.join("model", args.directory))

        # find a checkpoint having lowest loss
        min_loss = 1000000000000
        for name in checkpoint_names:
            loss = float(name.split("_")[0][7:])
            if loss < min_loss:
                model_name = name
                min_loss = loss

        MODEL_PATH = os.path.join("/home/tanaka301052/tegaki/model",
                                  args.directory, model_name)

        # load model and optimizer from the checkpoint
        checkpoint = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        lowest_loss = checkpoint['lowest_loss']
        global_step = checkpoint["global_step"]

        print("==>" + "model loaded:", model_name)
        print("current epoch:", start_epoch)
        print("lowest loss:", lowest_loss)
        print("resuming...")

    for epoch in range(start_epoch, EPOCH, 1):
        scheduler.step()

        # Training Phase
        model.train()
        train_loss = 0
        train_corrects = 0
        loss = 0

        for i, train_data in enumerate(tqdm(train_iter)):
            global_step += 1
            if args.semisupervised:
                # Get inputs for both student model and teacher model
                samples1, samples2 = train_data
                student_inputs, labels = samples1
                teacher_inputs, labels = samples2

                student_inputs = student_inputs.to(device)
                teacher_inputs = teacher_inputs.to(device)
                labels = labels.to(device)

                # forwarding student
                student_outputs = model(student_inputs)
                _, preds = torch.max(student_outputs, 1)

                # forwarding teacher
                teacher_outputs = teacher_model(teacher_inputs).detach()

                # classification loss for student
                classification_loss = loss_function(student_outputs, labels)
                # consistency loss between student and teacher
                consistency_loss = consistency_criterion(
                    softmax(student_outputs), softmax(teacher_outputs))
                # get weight of consistency loss
                consistency_weight = get_current_consistency_weight(epoch)

                # total loss
                loss = classification_loss + consistency_weight * consistency_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if epoch < int(args.epoch / 2):
                    # during ramp up.
                    alpha = 0.99
                else:
                    alpha = 0.999
                update_teacher(model,
                               teacher_model,
                               alpha=alpha,
                               global_step=global_step)

                train_loss += classification_loss.item() * student_inputs.size(
                    0)
                train_corrects += (preds == labels).sum().item()

            else:
                inputs, labels = train_data
                inputs = inputs.to(device)
                labels = labels.to(device)
                if not args.mixup:
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    train_corrects += (preds == labels).sum().item()
                elif args.mixup:
                    lam = sample_lambda_from_beta_distribution(
                        alpha=args.mixup_alpha)
                    lam = torch.from_numpy(np.array(
                        [lam]).astype('float32')).to(device)
                    output, reweighted_target = model(inputs,
                                                      lam=lam,
                                                      target=labels,
                                                      device=device)

                    loss = bce_loss(output, reweighted_target)

                train_loss += loss.item() * inputs.size(0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        train_loss /= dataset_sizes["train"]
        train_losses.append(train_loss)
        train_acc = train_corrects / dataset_sizes["train"]
        train_accs.append(train_acc)
        print("=====> train loss:{:5f} Acc:{:5f}".format(
            train_loss, train_acc))

        # Validation Phase
        model.eval()
        val_loss = 0
        loss = 0
        val_corrects = 0
        with torch.no_grad():
            for i, val_data in enumerate(val_iter):
                inputs, labels = val_data
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                max_index = outputs.max(dim=1)[1]

                loss = loss_function(outputs, labels)
                val_corrects += (max_index == labels).sum().item()

                val_loss += loss.item() * inputs.size(0)

        val_loss /= dataset_sizes["val"]
        val_losses.append(val_loss)

        result = {
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "lowest_loss": lowest_loss,
            "val_accuracy": val_corrects / dataset_sizes["val"],
        }

        val_accs.append(result["val_accuracy"])

        if val_loss < lowest_loss:
            lowest_loss = val_loss
            if epoch > 0:
                # Prevent saving the model in the first epoch.

                # remove latest saved model
                if os.path.isfile(PATH):
                    os.remove(PATH)

                model_name = "valloss{:.5f}_epoch{}.model".format(
                    val_loss, epoch)
                if not os.path.exists(model_save_dir):
                    os.makedirs(model_save_dir)
                PATH = os.path.join(model_save_dir, model_name)
                checkpoint = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model_state_dict": model.state_dict(),
                    "optim_state_dict": optimizer.state_dict(),
                    "lowest_loss": val_loss
                }

                torch.save(checkpoint, PATH)
                line(
                    "Project:{} \nEPOCH:{} train loss:{:4f}, val Accuracy:{:4f}, val loss:{:4f}, best loss:{:4f}"
                    .format(args.directory, epoch, train_loss,
                            result["val_accuracy"], result["val_loss"],
                            lowest_loss))

        print("EPOCH:{} val Accuracy:{:4f}, val loss:{:4f}, best loss:{:4f}".
              format(epoch, result["val_accuracy"], result["val_loss"],
                     lowest_loss))
        Logger.send_loss_img(str(args.directory),
                             train_losses,
                             val_losses,
                             process_name=args.directory)
        Logger.send_accuracy_img(str(args.directory),
                                 train_accs,
                                 val_accs,
                                 process_name=args.directory)

        output_eval_file = os.path.join(
            "result", "{}_training_log.txt".format(args.directory))
        Logger.write_log(result, output_eval_file)

        Logger.save_history(train_losses,
                            "save/{}_train_losses.pkl".format(args.directory))
        Logger.save_history(val_losses,
                            "save/{}_val_losses.pkl".format(args.directory))
Example #10
    img_resized_array = np.array(img_resized, dtype="uint8")
    return (img_resized_array, image_name)


image_information = Parallel(n_jobs=-1, verbose=-1)([
    delayed(open_image)(name, args.images_path, args.image_size)
    for name in image_names
])
image_holder, imagename_holder = [], []

for (image_array, image_name) in image_information:
    image_holder.append(image_array)
    imagename_holder.append(image_name)

model = densenet121(if_mixup=False,
                    if_selayer=True,
                    first_conv_stride=2,
                    first_pool=True).to(device)

checkpoint = torch.load(args.model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
print("==>" + "model loaded:", args.model_path)

ch_mean = [237.78516378, 237.16343756, 237.0501237]
ch_std = [146.47616225347468, 146.6169214951974, 145.59586636818233]

# dummy labels
labels = [0] * len(image_holder)

test = Dataset(image_holder,
               labels,
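The excerpt above opens with the tail of open_image(); a minimal reconstruction consistent with its call site and return value (the PIL resize details are assumptions):

import os
import numpy as np
from PIL import Image

def open_image(image_name, images_path, image_size):
    # load one image, force RGB, resize to a square, and return the
    # array with the name so Parallel results map back to their files
    img = Image.open(os.path.join(images_path, image_name)).convert("RGB")
    img_resized = img.resize((image_size, image_size))
    img_resized_array = np.array(img_resized, dtype="uint8")
    return (img_resized_array, image_name)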
Example #11
def train(epochs):
    """
    训练模型
    :param epochs:
    :return:
    """

    vgg = vgg16((None, 224, 224, 3), 102)
    resnet = resnet50((None, 224, 224, 3), 102)
    densenet = densenet121((None, 224, 224, 3), 102)
    models = [vgg, resnet, densenet]
    train_db, valid_db = load_db(32)
    his = []
    for model in models:
        variables = model.trainable_variables
        optimizers = tf.keras.optimizers.Adam(1e-4)
        for epoch in range(epochs):
            # training
            total_num = 0
            total_correct = 0
            training_loss = 0
            for step, (x, y) in enumerate(train_db):
                with tf.GradientTape() as tape:
                    # forward pass
                    out = model(x)
                    loss = tf.losses.categorical_crossentropy(
                        y, out, from_logits=False)
                    loss = tf.reduce_mean(loss)
                # compute and apply gradients outside the tape context
                grads = tape.gradient(loss, variables)
                optimizers.apply_gradients(zip(grads, variables))
                training_loss += float(loss)
                # training accuracy
                y_pred = tf.cast(tf.argmax(out, axis=1), dtype=tf.int32)
                y_true = tf.cast(tf.argmax(y, axis=1), dtype=tf.int32)
                correct = tf.reduce_sum(
                    tf.cast(tf.equal(y_pred, y_true), dtype=tf.int32))
                total_num += x.shape[0]
                total_correct += int(correct)
                if step % 100 == 0:
                    print("loss is {}".format(loss))
            training_accuracy = total_correct / total_num

            # validation
            total_num = 0
            total_correct = 0
            for (x, y) in valid_db:
                out = model(x)
                y_pred = tf.argmax(out, axis=1)
                y_pred = tf.cast(y_pred, dtype=tf.int32)
                y_true = tf.argmax(y, axis=1)
                y_true = tf.cast(y_true, dtype=tf.int32)
                correct = tf.cast(tf.equal(y_pred, y_true), dtype=tf.int32)
                correct = tf.reduce_sum(correct)
                total_num += x.shape[0]
                total_correct += int(correct)
            validation_accuracy = total_correct / total_num
            print(
                "epoch:{}, training loss:{:.4f}, training accuracy:{:.4f}, validation accuracy:{:.4f}"
                .format(epoch, training_loss, training_accuracy,
                        validation_accuracy))
            his.append({
                'accuracy': training_accuracy,
                'val_accuracy': validation_accuracy
            })
    return his
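A quick way to inspect the returned history; note that his concatenates the epochs of all three models back to back (the matplotlib plotting is a suggestion, not part of the source):

import matplotlib.pyplot as plt

his = train(epochs=10)
plt.plot([h['accuracy'] for h in his], label='train')
plt.plot([h['val_accuracy'] for h in his], label='val')
plt.xlabel('epoch (all three models, concatenated)')
plt.ylabel('accuracy')
plt.legend()
plt.show()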