Пример #1
0
def run(pth_path):
    """Evaluate a SimCLRv2 ResNet checkpoint on the ImageNet validation set.

    Loads the checkpoint at ``pth_path``, computes top-1 predictions over the
    50k validation images, then prints a per-class accuracy where, for each
    ground-truth class, the size of its largest prediction bucket counts as
    "correct".

    Args:
        pth_path: path to a ``.pth`` checkpoint containing a ``'resnet'``
            state dict (also parsed by ``name_to_params`` for the arch).
    """
    device = 'cuda'
    dataset = ImagenetValidationDataset('./val/')
    data_loader = DataLoader(dataset,
                             batch_size=64,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=8)
    model, _ = get_resnet(*name_to_params(pth_path))
    model.load_state_dict(torch.load(pth_path)['resnet'])
    model = model.to(device).eval()
    preds = []
    target = []
    # Pure inference: without no_grad() autograd keeps every batch's graph
    # alive and memory grows over the loop.
    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            _, pred = model(images.to(device), apply_fc=True).topk(1, dim=1)
            preds.append(pred.squeeze(1).cpu())
            target.append(labels)
    p = torch.cat(preds).numpy()
    t = torch.cat(target).numpy()
    # Per ground-truth class, count how often each label was predicted.
    all_counters = [Counter() for i in range(1000)]
    for i in range(50000):
        all_counters[t[i]][p[i]] += 1
    # A class's "correct" count is its single most frequent prediction.
    total_correct = 0
    for i in range(1000):
        total_correct += all_counters[i].most_common(1)[0][1]
    print(f'ACC: {total_correct / 50000 * 100}')
Пример #2
0
def evaluate(config):
    """Evaluate a SimCLRv2 checkpoint on an ImageNet-style eval set.

    For each ground-truth class, the count of its most frequent prediction
    is taken as "correct", and the aggregate percentage is printed.

    Args:
        config: dict with 'files' (images/labels/mappings/checkpoint paths),
            'device', and 'evaluate' (batch_size, top_k) entries.

    NOTE(review): the Counter bookkeeping assumes top-1 predictions --
    ``pred.squeeze(1)`` only flattens when ``top_k == 1``; confirm the
    config before using a larger ``top_k``.
    """
    files = config['files']
    dataset = ImageNetEval(files['images_path'], files['labels_path'],
                           files['mappings_path'])
    data_loader = DataLoader(dataset,
                             batch_size=config['evaluate']['batch_size'],
                             shuffle=False,
                             pin_memory=True,
                             num_workers=8)
    model, _ = get_resnet(*name_to_params(files['checkpoint']))
    model.load_state_dict(torch.load(files['checkpoint'])['resnet'])
    model = model.to(config['device']).eval()
    preds = []
    target = []
    # Pure inference: without no_grad() autograd retains every batch's
    # graph and memory grows over the loop.
    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            _, pred = model(images.to(config['device']),
                            apply_fc=True).topk(config['evaluate']['top_k'], dim=1)
            preds.append(pred.squeeze(1).cpu())
            target += labels
    p = torch.cat(preds).numpy()
    t = np.array(target, dtype=np.int32)
    # Per ground-truth class, count how often each label was predicted.
    all_counters = [Counter() for i in range(1000)]
    for i in range(len(dataset)):
        all_counters[t[i]][p[i]] += 1
    total_correct = 0
    for i in range(1000):
        total_correct += all_counters[i].most_common(1)[0][1]
    print(f'Accuracy: {total_correct / len(dataset) * 100}')
Пример #3
0
 def _load_model(self, network=None):
     """Construct the backbone network and expose its prediction tensor.

     Builds ``get_resnet(self.model_name, ...)`` under the
     ``model_network`` variable scope, stores the network on ``self.net``
     and its outputs on ``self.pred``.  The ``network`` argument is
     accepted for interface compatibility but is not used here.
     """
     with tf.variable_scope("model_network"):
         factory = get_resnet(self.model_name, num_classes=self.num_classes)
         self.net = factory(self.input)
         self.net.build_model()
         self.pred = self.net.outputs
Пример #4
0
def resnet_diff_test(layers_num):
    """Compare statistics of in-graph ResNet variables against a checkpoint.

    Builds the ResNet for ``layers_num`` layers, maps each graph variable
    name to its checkpoint tensor name, and prints per-variable and average
    differences of mean / median / std.

    Args:
        layers_num: ResNet depth selecting the checkpoint
            ``../model_weights/resnet_v1_<layers_num>.ckpt``.
    """
    ckpt_file_path = '../model_weights/resnet_v1_' + str(layers_num) + '.ckpt'
    x = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3], name='input_place')
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=tfconfig)
    nets = get_resnet(x, 1000, layers_num, sess)
    ckpt_static = get_tensor_static_val(ckpt_file_path, all_tensors=True, all_tensor_names=True)

    print('###########' * 30)
    # Renamed from `vars`, which shadowed the builtin.
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    total_count = 0
    mean_avg = 0.0
    median_avg = 0.0
    std_avg = 0.0

    # Graph variable names differ from checkpoint names only by these fixed
    # substrings; normalize before the checkpoint lookup.
    renames = (('_bn', ''),
               ('W_conv2d', 'weights'),
               ('b_conv2d', 'biases'),
               ('shortcut_conv', 'shortcut'))
    for var in global_vars:
        var_name_new = var.op.name
        for old, new in renames:
            var_name_new = var_name_new.replace(old, new)

        if var_name_new in ckpt_static:
            print(var_name_new, end=',    ')
            total_count += 1
            ckpt_s = ckpt_static[var_name_new]
            var_val = sess.run(var)
            mean_diff = np.mean(var_val) - ckpt_s.mean
            mean_avg += mean_diff
            median_diff = np.median(var_val) - ckpt_s.median
            median_avg += median_diff
            std_diff = np.std(var_val) - ckpt_s.std
            std_avg += std_diff
            print('mean_diff: ', mean_diff, 'median_diff: ', median_diff, 'std_diff: ', std_diff)

    # BUG FIX: the median average used to be labelled 'total_mean_diff' too.
    print('total_mean_diff', mean_avg / total_count, 'total_median_diff', median_avg / total_count,
          'total_std_diff', std_avg / total_count)
Пример #5
0
def train(config):
    """Fine-tune the final fc layer of a SimCLRv2 ResNet on ImageNet data.

    Only ``model.fc`` parameters are in the optimizer; the backbone weights
    are loaded from the checkpoint and never stepped (though the whole model
    runs in train mode).

    Args:
        config: dict with 'files' (paths incl. 'checkpoint'), 'device', and
            'train' (lr, beta1, beta2, batch_size, epochs) entries.
    """
    t_config = config['train']
    files = config['files']
    dataset = ImageNetEval(files['images_path'], files['labels_path'],
                           files['mappings_path'])
    data_loader = DataLoader(dataset,
                             batch_size=t_config['batch_size'],
                             shuffle=True,
                             pin_memory=True,
                             num_workers=8)
    model, _ = get_resnet(*name_to_params(files['checkpoint']))
    optimizer = optim.Adam(model.fc.parameters(),
                           lr=t_config['lr'],
                           betas=(t_config['beta1'], t_config['beta2']))
    model.load_state_dict(torch.load(files['checkpoint'])['resnet'])
    model = model.to(config['device']).train()
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = 'NA'
    for i in tqdm(range(t_config['epochs'])):
        # BUG FIX: tqdm takes `desc`, not `pbar` (the old kwarg raised
        # TypeError).
        for images, labels in tqdm(data_loader, desc=f'Loss: {loss}'):
            # BUG FIX: with apply_fc=True the model returns the logits
            # tensor directly (as the evaluation examples in this file use
            # it); `_, pred = model(...)` split the batch instead.
            pred = model(images.to(config['device']), apply_fc=True)
            # BUG FIX: labels must be on the same device as the logits.
            loss = loss_fn(pred, labels.to(config['device']))
            # BUG FIX: gradients accumulate across steps unless cleared.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Пример #6
0
def main(
    expt,
    model_name,
    device,
    gpu_id,
    optimizer,
    arch,
    num_layers,
    n_classes,
    img_size,
    batch_size,
    test_batch_size,
    subset,
    init_w,
    ckpt_g,
    n_epochs,
    lr_clfs,
    weight_decays,
    milestones,
    gamma,
):
    """Train per-label classifiers on top of a frozen, pre-trained generator.

    The generator ``net_G`` is loaded from ``ckpt_g`` and never updated; each
    label set with a positive class count gets its own classifier trained on
    ``net_G``'s output.  Per-epoch train/validation loss and accuracy are
    logged, plotted to ``cfg.ckpt_folder``, pickled, and the classifiers are
    saved at the end.

    BUG FIX: the ``optimizer == 'adam'`` branch previously constructed
    ``torch.optim.SGD``; it now builds ``torch.optim.Adam`` as requested.
    """
    device = torch_device(device, gpu_id[0])
    # One classifier per label set with a positive number of classes.
    num_clfs = len([_ for _ in n_classes if _ > 0])
    if arch == 'resnet':
        print('Using resnet')
        Net = get_resnet(num_layers)
    else:
        print('Using {}'.format(arch))
        Net = get_network(arch, num_layers)

    net_G = define_G(cfg.num_channels[expt],
                     cfg.num_channels[expt],
                     64,
                     gpu_id=device)
    clfs = [
        Net(num_channels=cfg.num_channels[expt], num_classes=_).to(device)
        for _ in n_classes if _ > 0
    ]

    if len(gpu_id) > 1:
        net_G = nn.DataParallel(net_G, device_ids=gpu_id)
        clfs = [nn.DataParallel(clf, device_ids=gpu_id) for clf in clfs]

    assert len(clfs) == num_clfs

    print("Loading weights...\n{}".format(ckpt_g))
    net_G.load_state_dict(torch.load(ckpt_g))
    if init_w:
        print("Init weights...")
        for clf in clfs:
            clf.apply(weights_init)

    scheduler = torch.optim.lr_scheduler.MultiStepLR
    if optimizer == 'sgd':
        opt_clfs = [
            torch.optim.SGD(clf.parameters(),
                            lr=lr,
                            momentum=0.9,
                            weight_decay=weight_decays[0])
            for lr, clf in zip(lr_clfs, clfs)
        ]
    elif optimizer == 'adam':
        # Was torch.optim.SGD, silently ignoring the requested optimizer.
        opt_clfs = [
            torch.optim.Adam(clf.parameters(),
                             lr=lr,
                             weight_decay=weight_decays[0])
            for lr, clf in zip(lr_clfs, clfs)
        ]
    sch_clfs = [
        scheduler(opt, milestones, gamma=gamma) for opt in opt_clfs
    ]

    assert len(opt_clfs) == num_clfs

    criterionNLL = nn.CrossEntropyLoss().to(device)

    train_loader = get_loader(expt,
                              batch_size,
                              True,
                              img_size=img_size,
                              subset=subset)
    valid_loader = get_loader(expt,
                              test_batch_size,
                              False,
                              img_size=img_size,
                              subset=subset)

    loss_history = defaultdict(list)
    acc_history = defaultdict(list)
    for epoch in range(n_epochs):
        logging.info(
            "Train Epoch " +
            ' '.join(["\t Clf: {}".format(_) for _ in range(num_clfs)]))

        for iteration, (image, labels) in enumerate(train_loader, 1):
            real = image.to(device)

            # The generator is frozen: no gradient flows back into net_G.
            with torch.no_grad():
                X = net_G(real)
            ys = [_.to(device) for _ in labels]

            [opt.zero_grad() for opt in opt_clfs]
            ys_hat = [clf(X) for clf in clfs]
            loss = [criterionNLL(y_hat, y) for y_hat, y in zip(ys_hat, ys)]
            ys_hat = [_.argmax(1, keepdim=True) for _ in ys_hat]
            acc = [
                y_hat.eq(y.view_as(y_hat)).sum().item() / len(y)
                for y_hat, y in zip(ys_hat, ys)
            ]
            [l.backward() for l in loss]
            [opt.step() for opt in opt_clfs]

            iloss = [l.item() for l in loss]
            assert len(iloss) == num_clfs

            logging.info('[{}]({}/{}) '.format(
                epoch,
                iteration,
                len(train_loader),
            ) + ' '.join([
                '\t {:.4f} ({:.2f})'.format(l, a) for l, a in zip(iloss, acc)
            ]))

        # Record the metrics of the *last* training batch for this epoch.
        loss_history['train_epoch'].append(epoch)
        acc_history['train_epoch'].append(epoch)
        for idx, (l, a) in enumerate(zip(iloss, acc)):
            loss_history['train_M_{}'.format(idx)].append(l)
            acc_history['train_M_{}'.format(idx)].append(a)

        logging.info(
            "Valid Epoch " +
            ' '.join(["\t Clf: {}".format(_) for _ in range(num_clfs)]))

        loss_m_batch = [0 for _ in range(num_clfs)]
        acc_m_batch = [0 for _ in range(num_clfs)]
        for iteration, (image, labels) in enumerate(valid_loader, 1):

            # Validation makes no parameter updates; skip autograd entirely.
            with torch.no_grad():
                X = net_G(image.to(device))
                ys = [_.to(device) for _ in labels]

                ys_hat = [clf(X) for clf in clfs]
                loss = [criterionNLL(y_hat, y) for y_hat, y in zip(ys_hat, ys)]
                ys_hat = [_.argmax(1, keepdim=True) for _ in ys_hat]
                acc = [
                    y_hat.eq(y.view_as(y_hat)).sum().item() / len(y)
                    for y_hat, y in zip(ys_hat, ys)
                ]

            iloss = [l.item() for l in loss]
            for idx, (l, a) in enumerate(zip(iloss, acc)):
                loss_m_batch[idx] += l
                acc_m_batch[idx] += a

            logging.info('[{}]({}/{}) '.format(
                epoch,
                iteration,
                len(valid_loader),
            ) + ' '.join([
                '\t {:.4f} ({:.2f})'.format(l, a) for l, a in zip(iloss, acc)
            ]))

        # Averages are per *batch* (len(valid_loader) counts batches).
        num_samples = len(valid_loader)
        logging.info('[{}](batch) '.format(epoch, ) + ' '.join([
            '\t {:.4f} ({:.2f})'.format(l / num_samples, a / num_samples)
            for l, a in zip(loss_m_batch, acc_m_batch)
        ]))

        loss_history['valid_epoch'].append(epoch)
        acc_history['valid_epoch'].append(epoch)
        for idx, (l, a) in enumerate(zip(loss_m_batch, acc_m_batch)):
            loss_history['valid_M_{}'.format(idx)].append(l / num_samples)
            acc_history['valid_M_{}'.format(idx)].append(a / num_samples)

        [sch.step() for sch in sch_clfs]

    train_loss_keys = [
        _ for _ in loss_history if 'train' in _ and 'epoch' not in _
    ]
    valid_loss_keys = [
        _ for _ in loss_history if 'valid' in _ and 'epoch' not in _
    ]

    # One subplot per classifier: loss (blue, left axis) and, where present,
    # accuracy (red, right axis); acc_history reuses the same M_<i> keys.
    cols = 5
    rows = len(train_loss_keys) // cols + 1
    fig = plt.figure(figsize=(7 * cols, 5 * rows))
    for idx, (tr_l, val_l) in enumerate(zip(train_loss_keys, valid_loss_keys)):
        ax = fig.add_subplot(rows, cols, idx + 1)
        ax.plot(loss_history['train_epoch'], loss_history[tr_l], 'b.:')
        ax.plot(loss_history['valid_epoch'], loss_history[val_l], 'bs-.')
        ax.set_xlabel('epochs')
        ax.set_ylabel('loss')
        ax.set_title(tr_l[6:])
        ax.grid()
        if tr_l in acc_history:
            ax2 = plt.twinx()
            ax2.plot(acc_history['train_epoch'], acc_history[tr_l], 'r.:')
            ax2.plot(acc_history['valid_epoch'], acc_history[val_l], 'rs-.')
            ax2.set_ylabel('accuracy')
    fig.subplots_adjust(wspace=0.4, hspace=0.3)
    plt_ckpt = '{}/{}/plots/{}.jpg'.format(cfg.ckpt_folder, expt, model_name)
    logging.info('Plot: {}'.format(plt_ckpt))
    plt.savefig(plt_ckpt, bbox_inches='tight', dpi=80)

    hist_ckpt = '{}/{}/history/{}.pkl'.format(cfg.ckpt_folder, expt,
                                              model_name)
    logging.info('History: {}'.format(hist_ckpt))
    # Close the file deterministically (was an unclosed open() handle).
    with open(hist_ckpt, 'wb') as f:
        pkl.dump((loss_history, acc_history), f)

    for idx, clf in enumerate(clfs):
        model_ckpt = '{}/{}/models/{}_clf_{}.stop'.format(
            cfg.ckpt_folder, expt, model_name, idx)
        logging.info('Model: {}'.format(model_ckpt))
        torch.save(clf.state_dict(), model_ckpt)
Пример #7
0
def main(
    expt,
    model_name,
    device,
    gpu_id,
    optimizer,
    num_layers,
    n_classes,
    img_size,
    batch_size,
    test_batch_size,
    subset,
    init_w,
    load_w,
    ckpt_g,
    ckpt_clfs,
    n_epochs,
    lr_g,
    lr_clfs,
    ei_array,
    weight_decays,
    milestones,
    save_ckpts,
    gamma,
):
    """Adversarially train generator ``net_G`` against per-label classifiers.

    Each training iteration first updates the classifiers on the generator's
    detached output, then updates the generator with ``eigan_loss`` over the
    classifier losses weighted by ``ei_array``.  Validation logs per-epoch
    averages, renders sample reconstructions, and checkpoints the generator
    at the epochs in ``save_ckpts``.  Finally the loss/accuracy history is
    plotted, pickled, and all models are saved.

    Cleanups vs. the original: validation runs under ``torch.no_grad``, a
    dead re-computation of the validation ``fake`` batch was removed, and
    the history pickle file is closed deterministically.
    """
    device = torch_device(device, gpu_id[0])
    # One classifier per label set with a positive number of classes.
    num_clfs = len([_ for _ in n_classes if _ > 0])
    Net = get_resnet(num_layers)

    net_G = define_G(cfg.num_channels[expt],
                     cfg.num_channels[expt],
                     64,
                     gpu_id=device)
    clfs = [
        Net(num_channels=cfg.num_channels[expt], num_classes=_).to(device)
        for _ in n_classes if _ > 0
    ]

    if len(gpu_id) > 1:
        net_G = nn.DataParallel(net_G, device_ids=gpu_id)
        clfs = [nn.DataParallel(clf, device_ids=gpu_id) for clf in clfs]

    assert len(clfs) == num_clfs

    if load_w:
        print("Loading weights...\n{}".format(ckpt_g))
        net_G.load_state_dict(torch.load(ckpt_g))
        for clf, ckpt in zip(clfs, ckpt_clfs):
            print(ckpt)
            clf.load_state_dict(torch.load(ckpt))
    elif init_w:
        print("Init weights...")
        net_G.apply(weights_init)
        for clf in clfs:
            clf.apply(weights_init)

    # `optim_cls` (was `optim`) no longer collides with the scheduler
    # comprehension's loop variable name.
    if optimizer == 'sgd':
        optim_cls = torch.optim.SGD
    elif optimizer == 'adam':
        optim_cls = torch.optim.Adam
    scheduler = torch.optim.lr_scheduler.MultiStepLR
    opt_G = torch.optim.Adam(net_G.parameters(),
                             lr=lr_g,
                             weight_decay=weight_decays[0])
    opt_clfs = [
        optim_cls(clf.parameters(), lr=lr, weight_decay=weight_decays[1])
        for lr, clf in zip(lr_clfs, clfs)
    ]
    sch_clfs = [
        scheduler(opt, milestones, gamma=gamma) for opt in opt_clfs
    ]

    assert len(opt_clfs) == num_clfs

    criterionNLL = nn.CrossEntropyLoss().to(device)

    train_loader = get_loader(expt,
                              batch_size,
                              True,
                              img_size=img_size,
                              subset=subset)
    valid_loader = get_loader(expt,
                              test_batch_size,
                              False,
                              img_size=img_size,
                              subset=subset)

    loss_history = defaultdict(list)
    acc_history = defaultdict(list)
    for epoch in range(n_epochs):
        logging.info(
            "Train Epoch \t Loss_G " +
            ' '.join(["\t Clf: {}".format(_) for _ in range(num_clfs)]))

        for iteration, (image, labels) in enumerate(train_loader, 1):
            real = image.to(device)
            # Keep only labels whose class count is positive, matching the
            # classifier list built above.
            ys = [
                _.to(device) for _, num_c in zip(labels, n_classes)
                if num_c > 0
            ]

            # Step 1: update the classifiers on the generator's output,
            # with the generator detached from the graph.
            with torch.no_grad():
                X = net_G(real)

            [opt.zero_grad() for opt in opt_clfs]
            ys_hat = [clf(X) for clf in clfs]
            loss = [criterionNLL(y_hat, y) for y_hat, y in zip(ys_hat, ys)]
            ys_hat = [_.argmax(1, keepdim=True) for _ in ys_hat]
            acc = [
                y_hat.eq(y.view_as(y_hat)).sum().item() / len(y)
                for y_hat, y in zip(ys_hat, ys)
            ]
            [l.backward() for l in loss]
            [opt.step() for opt in opt_clfs]

            iloss = [l.item() for l in loss]
            assert len(iloss) == num_clfs

            # Step 2: update the generator against the just-updated
            # classifiers via the weighted eigan loss.
            X = net_G(real)
            ys_hat = [clf(X) for clf in clfs]
            loss = [criterionNLL(y_hat, y) for y_hat, y in zip(ys_hat, ys)]

            opt_G.zero_grad()
            loss_g = eigan_loss(loss, ei_array)
            loss_g.backward()
            opt_G.step()

            logging.info('[{}]({}/{}) \t {:.4f} '.format(
                epoch, iteration, len(train_loader), loss_g.item()) +
                         ' '.join([
                             '\t {:.4f} ({:.2f})'.format(l, a)
                             for l, a in zip(iloss, acc)
                         ]))

        # Record the metrics of the *last* training batch for this epoch.
        loss_history['train_epoch'].append(epoch)
        loss_history['train_G'].append(loss_g.item())
        acc_history['train_epoch'].append(epoch)
        for idx, (l, a) in enumerate(zip(iloss, acc)):
            loss_history['train_M_{}'.format(idx)].append(l)
            acc_history['train_M_{}'.format(idx)].append(a)

        logging.info(
            "Valid Epoch \t Loss_G " +
            ' '.join(["\t Clf: {}".format(_) for _ in range(num_clfs)]))

        loss_g_batch = 0
        loss_m_batch = [0 for _ in range(num_clfs)]
        acc_m_batch = [0 for _ in range(num_clfs)]
        for iteration, (image, labels) in enumerate(valid_loader, 1):

            # Validation only logs metrics; skip autograd entirely.
            with torch.no_grad():
                fake = net_G(image.to(device))
                ys = [
                    _.to(device) for _, num_c in zip(labels, n_classes)
                    if num_c > 0
                ]

                ys_hat = [clf(fake) for clf in clfs]
                loss = [criterionNLL(y_hat, y) for y_hat, y in zip(ys_hat, ys)]
                ys_hat = [_.argmax(1, keepdim=True) for _ in ys_hat]
                acc = [
                    y_hat.eq(y.view_as(y_hat)).sum().item() / len(y)
                    for y_hat, y in zip(ys_hat, ys)
                ]

            iloss = [l.item() for l in loss]
            for idx, (l, a) in enumerate(zip(iloss, acc)):
                loss_m_batch[idx] += l
                acc_m_batch[idx] += a

            # (A dead re-computation of `fake` was removed here; nothing
            # below used its result.)
            loss_g = eigan_loss(iloss, ei_array)
            loss_g_batch += loss_g

            logging.info('[{}]({}/{}) \t {:.4f} '.format(
                epoch, iteration, len(valid_loader), loss_g) + ' '.join([
                    '\t {:.4f} ({:.2f})'.format(l, a)
                    for l, a in zip(iloss, acc)
                ]))

        # Averages are per *batch* (len(valid_loader) counts batches).
        num_samples = len(valid_loader)
        logging.info('[{}](batch) \t {:.4f} '.format(
            epoch, loss_g_batch / num_samples) + ' '.join([
                '\t {:.4f} ({:.2f})'.format(l / num_samples, a / num_samples)
                for l, a in zip(loss_m_batch, acc_m_batch)
            ]))

        loss_history['valid_epoch'].append(epoch)
        loss_history['valid_G'].append(loss_g_batch / num_samples)
        acc_history['valid_epoch'].append(epoch)
        for idx, (l, a) in enumerate(zip(loss_m_batch, acc_m_batch)):
            loss_history['valid_M_{}'.format(idx)].append(l / num_samples)
            acc_history['valid_M_{}'.format(idx)].append(a / num_samples)

        # Render 4 random samples from the last validation batch next to
        # their generated counterparts (top row: input, bottom: output).
        for i in range(image.shape[0]):
            j = np.random.randint(0, image.shape[0])
            sample = image[j]
            label = [str(int(_[j])) for _ in labels]
            ax = plt.subplot(2, 4, i + 1)
            ax.axis('off')
            sample = sample.permute(1, 2, 0)
            plt.imshow(sample.squeeze().numpy())
            plt.savefig('{}/{}/validation/tmp.jpg'.format(
                cfg.ckpt_folder, expt))
            ax = plt.subplot(2, 4, 5 + i)
            ax.axis('off')
            ax.set_title(" ".join(label))
            sample_G = net_G(sample.clone().permute(
                2, 0, 1).unsqueeze_(0).to(device))
            sample_G = sample_G.cpu().detach().squeeze()
            if sample_G.shape[0] == 3:
                sample_G = sample_G.permute(1, 2, 0)
            plt.imshow(sample_G.numpy())

            if i == 3:
                validation_plt = '{}/{}/validation/{}_{}.jpg'.format(
                    cfg.ckpt_folder, expt, model_name, epoch)
                print('Saving: {}'.format(validation_plt))
                plt.tight_layout()
                plt.savefig(validation_plt)
                break

        if epoch in save_ckpts:
            model_ckpt = '{}/{}/models/{}_g_{}.stop'.format(
                cfg.ckpt_folder, expt, model_name, epoch)
            logging.info('Model: {}'.format(model_ckpt))
            torch.save(net_G.state_dict(), model_ckpt)

        [sch.step() for sch in sch_clfs]

    train_loss_keys = [
        _ for _ in loss_history if 'train' in _ and 'epoch' not in _
    ]
    valid_loss_keys = [
        _ for _ in loss_history if 'valid' in _ and 'epoch' not in _
    ]

    # One subplot per tracked loss: loss (blue, left axis) and, where the
    # same key exists in acc_history, accuracy (red, right axis).
    cols = 5
    rows = len(train_loss_keys) // cols + 1
    fig = plt.figure(figsize=(7 * cols, 5 * rows))
    for idx, (tr_l, val_l) in enumerate(zip(train_loss_keys, valid_loss_keys)):
        ax = fig.add_subplot(rows, cols, idx + 1)
        ax.plot(loss_history['train_epoch'], loss_history[tr_l], 'b.:')
        ax.plot(loss_history['valid_epoch'], loss_history[val_l], 'bs-.')
        ax.set_xlabel('epochs')
        ax.set_ylabel('loss')
        ax.set_title(tr_l[6:])
        ax.grid()
        if tr_l in acc_history:
            ax2 = plt.twinx()
            ax2.plot(acc_history['train_epoch'], acc_history[tr_l], 'r.:')
            ax2.plot(acc_history['valid_epoch'], acc_history[val_l], 'rs-.')
            ax2.set_ylabel('accuracy')
    fig.subplots_adjust(wspace=0.4, hspace=0.3)
    plt_ckpt = '{}/{}/plots/{}.jpg'.format(cfg.ckpt_folder, expt, model_name)
    logging.info('Plot: {}'.format(plt_ckpt))
    plt.savefig(plt_ckpt, bbox_inches='tight', dpi=80)

    hist_ckpt = '{}/{}/history/{}.pkl'.format(cfg.ckpt_folder, expt,
                                              model_name)
    logging.info('History: {}'.format(hist_ckpt))
    # Close the file deterministically (was an unclosed open() handle).
    with open(hist_ckpt, 'wb') as f:
        pkl.dump((loss_history, acc_history), f)

    model_ckpt = '{}/{}/models/{}_g.stop'.format(cfg.ckpt_folder, expt,
                                                 model_name)
    logging.info('Model: {}'.format(model_ckpt))
    torch.save(net_G.state_dict(), model_ckpt)

    for idx, clf in enumerate(clfs):
        model_ckpt = '{}/{}/models/{}_clf_{}.stop'.format(
            cfg.ckpt_folder, expt, model_name, idx)
        logging.info('Model: {}'.format(model_ckpt))
        torch.save(clf.state_dict(), model_ckpt)
Пример #8
0
 def get_backbone(self, data, cfg):
     """Build the ResNet feature extractor for ``data``.

     Depth comes from ``cfg.network.NUM_LAYERS``; the per-stage strides
     are fixed at (1, 2, 2, 1).
     """
     stage_strides = [1, 2, 2, 1]
     depth = cfg.network.NUM_LAYERS
     return get_resnet(data, depth, stage_strides)
Пример #9
0
    img.fill(0)
    cv2.imwrite(path, img)
    return path


"""
################ The main code flow to call approriate functions ##########################
"""
if len(sys.argv) < 2:
    sys.exit(0)
cmd = sys.argv[1]
if cmd == "1":
    binary_model = get_unet(n_filters=16, dropout=0.05, batchnorm=True)
    modelname = "unet"
elif cmd == "2":
    binary_model = get_resnet(f=16, bn_axis=3, classes=1)
    modelname = "resnet"
elif cmd == "3":
    binary_model = get_segnet()
    modelname = "segnet"
else:
    binary_model = get_deeplab()
    modelname = "deeplab"
X_train, Y_train = get_class_for_generator("Train")
X_train, X_test, y_train, y_test = train_test_split(X_train,
                                                    Y_train,
                                                    test_size=0.30)
X_train = np.array(X_train)
X_test = np.array(X_test)
y_train = np.array(y_train)
y_test = np.array(y_test)
Пример #10
0
def main():
    """Convert a SimCLRv2 TensorFlow checkpoint into PyTorch ``.pth`` weights.

    Reads variables from the TF checkpoint at ``args.tf_path`` (the EMA copy
    when ``args.ema`` is set), maps convolution / batch-norm / fc tensors onto
    the PyTorch ResNet returned by ``get_resnet``, and saves a dict with
    'resnet' and 'head' state dicts.  With ``args.supervised`` set, the
    contrastive projection-head weights are skipped.
    """
    use_ema_model = args.ema
    prefix = ('ema_model/' if use_ema_model else '') + 'base_model/'
    head_prefix = ('ema_model/' if use_ema_model else '') + 'head_contrastive/'
    # 1. read tensorflow weight into a python dict
    vars_list = []
    contrastive_vars = []
    for v in tf.train.list_variables(args.tf_path):
        # Skip optimizer slots ('/Momentum'); collect backbone + supervised
        # fc variables in vars_list, projection-head variables separately.
        if v[0].startswith(prefix) and not v[0].endswith('/Momentum'):
            vars_list.append(v[0])
        elif v[0] in {
                'head_supervised/linear_layer/dense/bias',
                'head_supervised/linear_layer/dense/kernel'
        }:
            vars_list.append(v[0])
        elif v[0].startswith(head_prefix) and not v[0].endswith('/Momentum'):
            contrastive_vars.append(v[0])

    sd = {}
    ckpt_reader = tf.train.load_checkpoint(args.tf_path)
    for v in vars_list:
        sd[v] = ckpt_reader.get_tensor(v)

    # EMA checkpoints nest names one level deeper ('ema_model/base_model/...').
    split_idx = 2 if use_ema_model else 1
    # 2. convert the state_dict to PyTorch format
    # TF names convolutions 'conv2d', 'conv2d_1', ...; sort them numerically
    # so they line up with the definition order of nn.Conv2d modules below.
    conv_keys = [
        k for k in sd.keys()
        if k.split('/')[split_idx].split('_')[0] == 'conv2d'
    ]
    conv_idx = []
    for k in conv_keys:
        mid = k.split('/')[split_idx]
        if len(mid) == 6:
            # bare 'conv2d' (no numeric suffix) is the first conv
            conv_idx.append(0)
        else:
            conv_idx.append(int(mid[7:]))
    arg_idx = np.argsort(conv_idx)
    conv_keys = [conv_keys[idx] for idx in arg_idx]

    # Same numeric ordering for the 'batch_normalization[_N]' scopes.
    bn_keys = list(
        set([
            k.split('/')[split_idx] for k in sd.keys()
            if k.split('/')[split_idx].split('_')[0] == 'batch'
        ]))
    bn_idx = []
    for k in bn_keys:
        if len(k.split('_')) == 2:
            bn_idx.append(0)
        else:
            bn_idx.append(int(k.split('_')[2]))
    arg_idx = np.argsort(bn_idx)
    bn_keys = [bn_keys[idx] for idx in arg_idx]

    depth, width, sk_ratio = name_to_params(args.tf_path)
    model, head = get_resnet(depth, width, sk_ratio)

    # Collect PyTorch modules in definition order to pair with sorted TF keys.
    conv_op = []
    bn_op = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            conv_op.append(m)
        elif isinstance(m, nn.BatchNorm2d):
            bn_op.append(m)
    # 4 tensors per BN (gamma/beta/moving_mean/moving_variance).
    assert len(vars_list) == (len(conv_op) + len(bn_op) * 4 + 2)  # 2 for fc

    for i_conv in range(len(conv_keys)):
        m = conv_op[i_conv]
        # Reorder TF's (H, W, in, out) kernel layout to PyTorch's
        # (out, in, H, W).
        w = torch.from_numpy(sd[conv_keys[i_conv]]).permute(3, 2, 0, 1)
        assert w.shape == m.weight.shape, f'size mismatch {w.shape} <> {m.weight.shape}'
        m.weight.data = w

    for i_bn in range(len(bn_keys)):
        m = bn_op[i_bn]
        gamma = torch.from_numpy(sd[prefix + bn_keys[i_bn] + '/gamma'])
        assert m.weight.shape == gamma.shape, f'size mismatch {gamma.shape} <> {m.weight.shape}'
        m.weight.data = gamma
        m.bias.data = torch.from_numpy(sd[prefix + bn_keys[i_bn] + '/beta'])
        m.running_mean = torch.from_numpy(sd[prefix + bn_keys[i_bn] +
                                             '/moving_mean'])
        m.running_var = torch.from_numpy(sd[prefix + bn_keys[i_bn] +
                                            '/moving_variance'])

    # Supervised linear head: TF dense kernels are transposed vs. nn.Linear.
    w = torch.from_numpy(sd['head_supervised/linear_layer/dense/kernel']).t()
    assert model.fc.weight.shape == w.shape
    model.fc.weight.data = w
    b = torch.from_numpy(sd['head_supervised/linear_layer/dense/bias'])
    assert model.fc.bias.shape == b.shape
    model.fc.bias.data = b

    if args.supervised:
        # Supervised-only export: save backbone + (untouched) head and stop.
        save_location = f'r{depth}_{width}x_sk{1 if sk_ratio != 0 else 0}{"_ema" if use_ema_model else ""}.pth'
        torch.save({
            'resnet': model.state_dict(),
            'head': head.state_dict()
        }, save_location)
        return
    # Contrastive projection head: one linear + batch-norm pair per 'nl_<i>'.
    sd = {}
    for v in contrastive_vars:
        sd[v] = ckpt_reader.get_tensor(v)
    linear_op = []
    bn_op = []
    for m in head.modules():
        if isinstance(m, nn.Linear):
            linear_op.append(m)
        elif isinstance(m, nn.BatchNorm1d):
            bn_op.append(m)
    for i, (l, m) in enumerate(zip(linear_op, bn_op)):
        l.weight.data = torch.from_numpy(
            sd[f'{head_prefix}nl_{i}/dense/kernel']).t()
        common_prefix = f'{head_prefix}nl_{i}/batch_normalization/'
        m.weight.data = torch.from_numpy(sd[f'{common_prefix}gamma'])
        # NOTE(review): the last BN (i == 2) keeps its default bias --
        # presumably the TF head has no beta there; confirm against the
        # SimCLRv2 head definition.
        if i != 2:
            m.bias.data = torch.from_numpy(sd[f'{common_prefix}beta'])
        m.running_mean = torch.from_numpy(sd[f'{common_prefix}moving_mean'])
        m.running_var = torch.from_numpy(sd[f'{common_prefix}moving_variance'])

    # 3. dump the PyTorch weights.
    save_location = f'r{depth}_{width}x_sk{1 if sk_ratio != 0 else 0}{"_ema" if use_ema_model else ""}.pth'
    torch.save({
        'resnet': model.state_dict(),
        'head': head.state_dict()
    }, save_location)
Пример #11
0
import torch.utils.data as data
from torchvision import transforms

from resnet_data.inceptionresnetv2.pytorch_load import inceptionresnetv2, InceptionResnetV2
from cloud_bm_v2 import ToTensor, Normalization, AmazonDataSet, read_data, train, Scale, RandomHorizontalFlip, RandomVerticalFlip, RandomSizedCrop
from resnet import get_resnet

if __name__ == "__main__":
    print("Started")
    parser = argparse.ArgumentParser()
    parser.add_argument("--load_weights", default=None, type=str)
    parser.add_argument("--img_dir", default="train/train-jpg/", type=str)
    args = parser.parse_args()
    batch_size = 17

    in_res = get_resnet([0], 4, sigmoid=False, dropout=True)
    print("Batch size: {}".format(batch_size))

    data_transform = transforms.Compose([
        Scale(),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        #RandomSizedCrop(),
        ToTensor(),
        Normalization(),
    ])
    val_transform = transforms.Compose([Scale(), ToTensor(), Normalization()])

    if args.load_weights:
        in_res.load_state_dict(torch.load(args.load_weights))
        print("Loaded weights from {}".format(args.load_weights))