示例#1
0
def defineNet(net,
              gpu_ids=(0, 1, 2, 3),
              use_weights_init=True,
              use_checkpoint=False,
              show_structure=False):
    """Prepare a network for CPU, single-GPU or multi-GPU execution.

    :param net: the network, an instance of ``torch.nn.Module``.
    :param gpu_ids: tuple of GPU ids to use; pass ``()`` to stay on CPU.
    :param use_weights_init: if True, apply ``weightsInit`` (Kaiming-He style
        initialisation) to every sub-module.
    :param use_checkpoint: if True, reload the model from a checkpoint after
        initialisation and switch it to train mode.
    :param show_structure: forwarded to ``print_network`` to control how much
        of the architecture is printed.
    :return: the (possibly GPU-moved / DataParallel-wrapped) network.
    """
    print_network(net, show_structure)
    gpu_available = torch.cuda.is_available()
    model_name = type(net)
    # Fix: use boolean `and` instead of bitwise `&` -- `and` short-circuits
    # and does not rely on the parentheses to beat `&`'s tighter precedence.
    if len(gpu_ids) == 1 and gpu_available:
        net = net.cuda(gpu_ids[0])
        print("%s model use GPU(%d)!" % (model_name, gpu_ids[0]))
    elif len(gpu_ids) > 1 and gpu_available:
        net = DataParallel(net.cuda(), gpu_ids)
        print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids))
    else:
        print("%s model use CPU!" % (model_name))

    if use_weights_init:
        net.apply(weightsInit)
        print("apply weight init!")
    if use_checkpoint:
        # NOTE(review): `model_path` and `model_weights_path` are not defined
        # in this scope -- this branch raises NameError unless they exist as
        # globals.  Confirm where these paths are supposed to come from.
        net = loadModel(model_path, model_weights_path, gpus=None).train()
    return net
示例#2
0
    def define(self,
               proto_model,
               gpu_ids,
               use_weights_init=True,
               show_structure=False):
        """Place the network on CPU/GPU(s) and store it on ``self.model``.

        :param proto_model: the network, an instance of ``torch.nn.Module``.
        :param gpu_ids: tuple of GPU ids to use; pass ``()`` to stay on CPU.
        :param use_weights_init: if True, apply ``self._weightsInit``
            (Kaiming-He style initialisation) to every sub-module.
        :param show_structure: forwarded to ``self.print_network``.
        :return: None; the prepared network is stored in ``self.model``.
        """
        self.print_network(proto_model, show_structure)
        gpu_available = torch.cuda.is_available()
        model_name = type(proto_model)
        # Fix: boolean `and` instead of bitwise `&` (short-circuits, no
        # precedence surprises).
        if len(gpu_ids) == 1 and gpu_available:
            proto_model = proto_model.cuda(gpu_ids[0])
            print("%s model use GPU(%d)!" % (model_name, gpu_ids[0]))
        elif len(gpu_ids) > 1 and gpu_available:
            proto_model = DataParallel(proto_model.cuda(), gpu_ids)
            print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids))
        else:
            print("%s model use CPU!" % (model_name))

        if use_weights_init:
            proto_model.apply(self._weightsInit)
            print("apply weight init!")
        self.model = proto_model
示例#3
0
class rpn_initor():
    """Build and wire up the RPN training components on GPU 0."""

    def __init__(self):
        # Hyper-parameters and running-best bookkeeping.
        self.lr1 = 0.15
        self.max_pre = 0
        self.max_acc = 0
        self.max_recall = 0
        self.batch = 240
        # Restrict this process to the first CUDA device.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'

        # Backbone in inference mode, restored from the saved checkpoint.
        self.features = mb().eval()
        weight_file = os.path.join(os.getcwd(), "base_max.p")
        state = load(weight_file)
        print(weight_file)
        self.features.load_state_dict(state)

        # Fresh RPN head; move both networks onto the GPU.
        self.RPN = RPN()
        self.features = self.features.cuda()
        self.RPN = self.RPN.cuda()

        # Only the RPN is (re-)initialised; the backbone keeps its loaded
        # weights.  Both are wrapped for (single-device) DataParallel use.
        self.RPN.apply(weights_init)
        self.features = DataParallel(self.features, device_ids=[0])
        self.RPN = DataParallel(self.RPN, device_ids=[0])
示例#4
0
class rcn_initor():
    """Assemble the RCN pipeline: frozen backbone + RPN, trainable ROI head."""

    def __init__(self):
        self.pool_size = [8, 4, 2, 1]
        # Pin everything to the first CUDA device.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'

        self.tool = rpn_tool_d()
        self.tool2 = rcn_tool_c()

        path_base = os.path.join(os.getcwd(), 'base_a1_max.p')
        path_RPN = os.path.join(os.getcwd(), 'rpn_a1_max.p')
        print(path_base)
        print(path_RPN)

        # ROI pooling operators, one per feature-map scale.
        from tool.batch.roi_layers import ROIPool
        self.pool4 = ROIPool(self.pool_size[3], 1 / 64)
        self.pool3 = ROIPool(self.pool_size[2], 1 / 32)
        self.pool2 = ROIPool(self.pool_size[1], 1 / 16)
        self.pool1 = ROIPool(self.pool_size[0], 1 / 8)

        self.pre = pre().cuda()
        self.ROI = roi().cuda()
        self.RPN = RPN().cuda().eval()
        self.features = mb().cuda().eval()

        # Restore the pretrained RPN and backbone weights.
        self.RPN.load_state_dict(load(path_RPN))
        self.features.load_state_dict(load(path_base))

        self.features = DataParallel(self.features, device_ids=[0])
        self.RPN = DataParallel(self.RPN, device_ids=[0])
        self.ROI = DataParallel(self.ROI, device_ids=[0])
        self.pre = DataParallel(self.pre, device_ids=[0])

        # Log parameter counts, then (re-)initialise only the ROI head.
        get_paprams(self.features)
        get_paprams(self.RPN)
        get_paprams(self.ROI)
        self.ROI.apply(weights_init)
        self.batch = True
        self.flag = 3
示例#5
0
class model():
    """Training / evaluation driver for the deep-supervised 10s classifier."""

    def __init__(self):
        # Fix: the original spelled this ``__int__``, so it was never invoked
        # as the constructor.  The body is intentionally empty.
        pass

    def train(self):
        """Train the classifier with mixup and five deep-supervision heads,
        checkpointing and evaluating on the SMOTE test split every epoch."""
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        device_ids = [0]
        self.classifier = classifier()
        get_paprams(self.classifier)
        get_paprams(self.classifier.base)
        data_set = my_dataset_10s_smote()
        # The test split reuses the training split's data/labels/index so the
        # two splits stay consistent with each other.
        data_set_test = my_dataset_10s_smote(test=True, all_data=data_set.all_data, all_label=data_set.all_label,
                                             index_=data_set.index)
        batch = 300
        total_epoch = 2000
        print('batch:{}'.format(batch))
        data_loader = DataLoader(data_set, batch, shuffle=True, collate_fn=detection_collate)
        data_loader_test = DataLoader(data_set_test, batch, False, collate_fn=detection_collate)
        self.classifier = self.classifier.cuda()
        self.classifier = DataParallel(self.classifier, device_ids=device_ids)
        optim = Adadelta(self.classifier.parameters(), 0.1, 0.9, weight_decay=1e-5)

        self.cretion = smooth_focal_weight()

        self.classifier.apply(weights_init)
        start_time = time.time()
        # Per-class loss weights; hoisted out of the loop (it is constant).
        weight = torch.Tensor([0.5, 2, 0.5, 2]).cuda()
        # Fix: bounded loop -- the original ``while 1`` never terminated even
        # though the epoch budget was defined for exactly this purpose.
        for epoch in range(total_epoch):
            runing_losss = [0] * 5
            for data in data_loader:
                y = data[1].cuda()
                x = data[0].cuda()
                optim.zero_grad()

                inputs, targets_a, targets_b, lam = mixup_data(x, y)
                # Fix: feed the *mixed* inputs to the network.  The original
                # passed the raw ``x``, so the mixup targets never matched the
                # samples the model actually saw.
                predict = self.classifier(inputs)

                loss_func = mixup_criterion(targets_a, targets_b, lam, weight)
                # Deep supervision: deepest head at full weight, shallower
                # heads progressively down-weighted.
                losses = [
                    loss_func(self.cretion, predict[0]),
                    loss_func(self.cretion, predict[1]) * 0.4,
                    loss_func(self.cretion, predict[2]) * 0.3,
                    loss_func(self.cretion, predict[3]) * 0.2,
                    loss_func(self.cretion, predict[4]) * 0.1,
                ]
                total_loss = sum(losses)
                total_loss.backward()
                optim.step()
                # Fix: track each head's loss in its own slot (the original
                # accumulated the combined loss into all five slots).
                for i in range(5):
                    runing_losss[i] += losses[i].item()
            end_time = time.time()
            print(
                "epoch:{a}: loss:{b} spend_time:{c} time:{d}".format(a=epoch, b=sum(runing_losss),
                                                                     c=int(end_time - start_time),
                                                                     d=time.asctime()))
            start_time = end_time

            # Checkpoint the backbone and the full model every epoch.
            save(self.classifier.module.base.state_dict(),
                 str(epoch) + 'base_c2.p')
            save(self.classifier.module.state_dict(),
                 str(epoch) + 'base_all_c2.p')
            self.classifier.eval()
            self.evaluation(self.classifier, data_loader_test, epoch)
            self.classifier.train()
            if epoch % 10 == 0:
                adjust_learning_rate(optim, 0.9, epoch, total_epoch, 0.1)

    def evaluation(self, classifier, data_loader_test, epoch):
        """Print accuracy / precision / recall / F1 / confusion matrix for all
        five prediction heads over ``data_loader_test``.

        :param classifier: network, expected to already be in ``eval()`` mode.
        :param data_loader_test: loader yielding ``(x, y, ...)`` batches.
        :param epoch: kept for interface compatibility; not used here.
        """
        all_predict = [[], [], [], [], []]
        all_ground = []
        with torch.no_grad():
            for data in data_loader_test:
                y = data[1].cuda()
                x = data[0].cuda()
                predict_list = classifier(x)
                for i in range(5):
                    _, index = torch.max(predict_list[i], 1)
                    all_predict[i].extend(index.tolist())
                all_ground.extend(y.tolist())
        # Headline metrics for the main (deepest) head.
        print("Accuracy:{}".format(metrics.accuracy_score(all_ground, all_predict[0])))
        print('precesion:{}'.format(metrics.precision_score(all_ground, all_predict[0], average=None)))
        print('recall:{}'.format(metrics.recall_score(all_ground, all_predict[0], average=None)))
        print('f-score:{}'.format(metrics.f1_score(all_ground, all_predict[0], average=None)))
        print("{}".format(metrics.confusion_matrix(all_ground, all_predict[0])))
        # Per-head breakdown (head 0 repeats the headline numbers).
        for i in range(5):
            print("Accuracy:{}".format(metrics.accuracy_score(all_ground, all_predict[i])))
            print('precesion:{}'.format(metrics.precision_score(all_ground, all_predict[i], average=None)))
            print('recall:{}'.format(metrics.recall_score(all_ground, all_predict[i], average=None)))
            print('f-score:{}'.format(metrics.f1_score(all_ground, all_predict[i], average=None)))

    def test(self):
        """Evaluate on the SMOTE test split and print the final metrics.

        NOTE(review): this builds a brand-new ``classifier`` and never loads
        saved weights -- a ``load_state_dict`` call appears to be missing.
        Also, unlike ``train``, the forward pass here takes
        ``(x, every_len, max_len)``; confirm which signature is current.
        """
        self.classifier = classifier().cuda()
        data_set = my_dataset_10s_smote(test=True)
        data_loader_test = DataLoader(data_set, 300, False, collate_fn=detection_collate)
        all_predict = []
        all_ground = []
        self.classifier.eval()
        self.classifier.base.eval()
        with torch.no_grad():
            for data in data_loader_test:
                y = data[1].cuda()
                x = data[0].cuda()
                every_len = data[2]
                max_len = data[3]
                # Only the deepest head's prediction is scored.
                predict = self.classifier(x, every_len, max_len)[0]
                _, index = torch.max(predict, 1)
                all_predict.extend(list(index.cpu().numpy()))
                all_ground.extend(list(y.cpu().numpy()))
        print(metrics.precision_score(all_ground, all_predict, average=None))
        print(metrics.recall_score(all_ground, all_predict, average=None))
        print(metrics.f1_score(all_ground, all_predict, average=None))
        print(metrics.confusion_matrix(all_ground, all_predict))
示例#6
0
# Initialize generator and discriminator
generator = my_model.Generator()
discriminator = my_model.Discriminator()

# `device_for_data` pins tensors to a specific card, while `device_for_model`
# uses the generic 'cuda' device so DataParallel can fan out across GPUs.
device_for_data = torch.device('cuda:0' if cuda else 'cpu')
device_for_model = torch.device('cuda' if cuda else 'cpu')

if cuda:
    # Wrap for multi-GPU execution, then move the parameters onto the GPUs.
    generator = DataParallel(generator)
    generator.to(device_for_model)
    discriminator = DataParallel(discriminator)
    discriminator.to(device_for_model)

# Initialize weights
# (Module.apply recurses into the module wrapped by DataParallel, so
# initialising after wrapping still reaches every layer.)
generator.apply(my_model.weights_init_normal)
discriminator.apply(my_model.weights_init_normal)

# Resize every training image to 256x256 and convert it to a [0, 1] tensor.
transform_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
dataset = my_dataset.MyDataset(data_path=data_dir, transform=transform_train)
# num_workers=0: batches are loaded in the main process.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=0)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.lr,
示例#7
0
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    datadir = config_detector['preprocess_result_path']
    print 'datadir = ', datadir

    net = DataParallel(net, device_ids=[0])

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('Linear') != -1:
            m.bias.data.fill_(0)

    # Cross-Validation of 3D-semi, train
    k_fold = args.fold
    print "Authorizing fold: {:d}".format(k_fold)

    # Loading training set
    dataset = data.DataBowl3Detector(
        datadir,
        'detector/luna_file_id/subset_fold{:d}'.format(k_fold) +
        '/file_id_rpn_train.npy',
        config,
        phase='train')
    rpn_train_loader = DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    # Training process
    train_loss_l, train_tpr_l = [], []

    # weights initialize
    net.apply(weights_init)

    for epoch in range(start_epoch, args.epochs + 1):
        if not os.path.exists(os.path.join(save_dir,
                                           'fold{:d}'.format(k_fold))):
            os.makedirs(os.path.join(save_dir, 'fold{:d}'.format(k_fold)))
        train_loss, train_tpr = train(
            rpn_train_loader, net, loss, epoch, optimizer, get_lr,
            args.save_freq, os.path.join(save_dir, 'fold{:d}'.format(k_fold)))

        # Append loss results
        train_loss_l.append(train_loss)
        train_tpr_l.append(train_tpr)

    # Save Train-Validation results
    if not os.path.exists('./train-vali-results/fold{:d}'.format(k_fold)):
        os.makedirs('./train-vali-results/fold{:d}'.format(k_fold))
    np.save(
        './train-vali-results/fold{:d}'.format(k_fold) + '/rpn-train-loss.npy',
        np.asarray(train_loss_l).astype(np.float64))
    np.save(
        './train-vali-results/fold{:d}'.format(k_fold) + '/rpn-train-tpr.npy',
        np.asarray(train_tpr_l).astype(np.float64))

    # Testing process
    if args.test == 1:
        margin = 32
        sidelen = 144

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(
            datadir,
            'detector/luna_file_id/subset_fold{:d}'.format(k_fold) +
            '/file_id_test.npy',
            config,
            phase='test',
            split_comber=split_comber)
        test_loader = DataLoader(
            dataset,
            batch_size=1,  # 在测试阶段,batch size 固定为1
            shuffle=False,
            num_workers=args.workers,
            collate_fn=data.collate,
            pin_memory=False)

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(
            datadir,
            'detector/luna_file_id/subset_fold{:d}'.format(k_fold) +
            '/file_id_total_train.npy',
            config,
            phase='test',
            split_comber=split_comber)

        train_total_loader = DataLoader(
            dataset,
            batch_size=1,  # 在测试阶段,batch size 固定为1
            shuffle=False,
            num_workers=args.workers,
            collate_fn=data.collate,
            pin_memory=False)

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(
            datadir,
            'detector/luna_file_id/file_id_unlabel.npy',
            config,
            phase='test',
            split_comber=split_comber)

        unlabel_loader = DataLoader(
            dataset,
            batch_size=1,  # 在测试阶段,batch size 固定为1
            shuffle=False,
            num_workers=args.workers,
            collate_fn=data.collate,
            pin_memory=False)

        test_dir = os.path.join(save_dir, 'voi_fold{:d}'.format(k_fold),
                                'test')
        if not os.path.exists(test_dir):
            os.makedirs(test_dir)
        find_voi(test_loader, net, get_pbb, test_dir, config)

        total_train_dir = os.path.join(save_dir, 'voi_fold{:d}'.format(k_fold),
                                       'total_train')
        if not os.path.exists(total_train_dir):
            os.makedirs(total_train_dir)
        find_voi(train_total_loader, net, get_pbb, total_train_dir, config)

        unlabel_dir = os.path.join(save_dir, 'voi_fold{:d}'.format(k_fold),
                                   'unlabel')
        if not os.path.exists(unlabel_dir):
            os.makedirs(unlabel_dir)
        find_voi(unlabel_loader, net, get_pbb, unlabel_dir, config)
示例#8
0
class model(base_process):
    """Two-stage trainer for the ROI classification head.

    ``base_process`` is expected to provide the frozen backbone
    (``self.features``), the RPN (``self.RPN``), the heads
    (``self.pre`` / ``self.ROI``), the helpers (``self.tool`` /
    ``self.tool2``), the pooling ops (``self.pool1..4``) and ``self.flag``.
    """

    def train_stage_2(self):
        """Stage-2 training.

        Phase 1 (200 epochs): train ``self.pre`` + ``self.ROI`` on proposals
        produced by the frozen backbone/RPN pipeline.
        Phase 2: harvest pooled ROI features from every split, rebalance them
        with SMOTE, rebuild ``self.ROI`` from scratch and train it alone for
        1200 epochs on the resampled features, evaluating every epoch.
        """

        batch = 240
        lr1 = 0.15  # NOTE(review): unused -- both optimisers hard-code lr=0.15

        data_set = loader(os.path.join(os.getcwd(), 'data_2'), {"mode": "training"})
        data_set_test = loader(os.path.join(os.getcwd(), 'data_2'),{"mode": "test"}, data_set.index)
        data_set_eval = loader(os.path.join(os.getcwd(), 'data_2'),{"mode": "eval"}, data_set.index)

        data_loader = DataLoader(data_set, batch, True, collate_fn=call_back.detection_collate_RPN)
        data_loader_test = DataLoader(data_set_test, batch, False, collate_fn=call_back.detection_collate_RPN)
        data_loader_eval = DataLoader(data_set_eval, batch, False, collate_fn=call_back.detection_collate_RPN)

        # optim = Adadelta(self.ROI.parameters(), lr=lr1, weight_decay=1e-5)
        start_time = time.time()
        # Phase-1 optimiser: updates both the fusion head and the ROI head.
        optim_a = Adadelta([{'params': self.pre.parameters()},
                            {'params': self.ROI.parameters()}], lr=0.15, weight_decay=1e-5)
        cfg.test = False
        count = 0
        # ---- Phase 1: end-to-end training of pre + ROI on live proposals ----
        for epoch in range(200):
            runing_losss = 0.0
            cls_loss = 0
            coor_loss = 0
            cls_loss2 = 0
            coor_loss2 = 0  # never updated in this loop; always printed as 0
            count += 1
            # base_time = RPN_time = ROI_time = nms_time = pre_gt = loss_time = linear_time = 0
            for data in data_loader:
                y = data[1]
                x = data[0].cuda()
                peak = data[2]
                num = data[3]  # NOTE(review): unused
                optim_a.zero_grad()

                # Proposal generation runs without gradients; the feat*/label/
                # loss names assigned here intentionally leak out of the
                # `with` block and are used below.
                with torch.no_grad():
                    if self.flag >= 2:
                        result = self.base_process(x, y, peak)
                        feat1 = result['feat_8']
                        feat2 = result['feat_16']
                        feat3 = result['feat_32']
                        feat4 = result['feat_64']
                        label = result['label']
                        loss_box = result['loss_box']
                        cross_entropy = result['cross_entropy']

                cls_score = self.pre(feat1, feat2, feat3, feat4)
                cls_score = self.ROI(cls_score)

                cross_entropy2 = self.tool2.cal_loss2(cls_score, label)

                # Only the head's classification loss is back-propagated; the
                # RPN's cross_entropy / loss_box are logged for monitoring.
                loss_total = cross_entropy2
                loss_total.backward()
                optim_a.step()
                runing_losss += loss_total.item()
                cls_loss2 += cross_entropy2.item()
                cls_loss += cross_entropy.item()
                coor_loss += loss_box.item()
            end_time = time.time()
            torch.cuda.empty_cache()
            print(
                "epoch:{a} time:{ff}: loss:{b:.4f} cls:{d:.4f} cor{e:.4f} cls2:{f:.4f} cor2:{g:.4f} date:{fff}".format(
                    a=epoch,
                    b=runing_losss,
                    d=cls_loss,
                    e=coor_loss,
                    f=cls_loss2,
                    g=coor_loss2, ff=int(end_time - start_time),
                    fff=time.asctime()))
            # if epoch % 10 == 0:
            #     adjust_learning_rate(optim, 0.9, epoch, 50, lr1)
            p = None

            # if epoch % 2 == 0:
            #     print("test result")
            # save(self.RPN.module.state_dict(),
            #      os.path.join(os.getcwd(), str(epoch) + 'rpn_a2.p'))
            # save(self.RPN.module.state_dict(),
            #      os.path.join(os.getcwd(), str(epoch) + 'base_a2.p'))
            start_time = end_time
        # ---- Phase 2 harvest: pooled features from train/eval/test splits ----
        all_data = []
        all_label = []
        for data in data_loader:
            y = data[1]
            x = data[0].cuda()
            num = data[3]
            peak = data[2]
            with torch.no_grad():
                if self.flag >= 2:
                    result = self.base_process_2(x, y, peak)
                    data_ = result['x']
                    label = result['label']
                    loss_box = result['loss_box']
                    cross_entropy = result['cross_entropy']
                    all_data.extend(data_.cpu())
                    all_label.extend(label.cpu())
        for data in data_loader_eval:
            y = data[1]
            x = data[0].cuda()
            num = data[3]
            peak = data[2]
            with torch.no_grad():
                if self.flag >= 2:
                    result = self.base_process_2(x, y, peak)
                    data_ = result['x']
                    label = result['label']
                    loss_box = result['loss_box']
                    cross_entropy = result['cross_entropy']
                    all_data.extend(data_.cpu())
                    all_label.extend(label.cpu())
        for data in data_loader_test:
            y = data[1]
            x = data[0].cuda()
            num = data[3]
            peak = data[2]
            with torch.no_grad():
                if self.flag >= 2:
                    result = self.base_process_2(x, y, peak)
                    data_ = result['x']
                    label = result['label']
                    loss_box = result['loss_box']
                    cross_entropy = result['cross_entropy']
                    all_data.extend(data_.cpu())
                    all_label.extend(label.cpu())

        # Rebalance the harvested features with SMOTE, then re-split 70/20.
        all_data = torch.stack(all_data, 0).numpy()
        all_label = torch.LongTensor(all_label).numpy()
        from imblearn.over_sampling import SMOTE
        fun = SMOTE()
        all_data, all_label = fun.fit_resample(all_data, all_label)
        total = len(all_label)
        training_label = all_label[:int(0.7 * total)]
        training_data = all_data[:int(0.7 * total)]

        test_label = all_label[-int(0.2 * total):]
        test_data = all_data[-int(0.2 * total):]
        count = 0
        # Fresh, re-initialised ROI head trained alone on resampled features.
        self.ROI = roi().cuda()
        self.ROI = DataParallel(self.ROI, device_ids=[0])
        self.ROI.apply(weights_init)

        optim_b = Adadelta(self.ROI.parameters(), lr=0.15, weight_decay=1e-5)
        # ---- Phase 2 training: ROI head on SMOTE-balanced features ----
        for epoch in range(1200):
            runing_losss = 0.0
            cls_loss = 0
            coor_loss = 0
            cls_loss2 = 0
            coor_loss2 = 0
            count += 1
            optim_b.zero_grad()
            optim_a.zero_grad()

            # base_time = RPN_time = ROI_time = nms_time = pre_gt = loss_time = linear_time = 0
            for j in range(int(len(training_label) / 240)):
                # Fixed chunks of 240 samples reshaped to (240, 1024, 15).
                data_ = torch.Tensor(training_data[j * 240:j * 240 + 240]).view(240, 1024, 15).cuda()
                label_ = torch.LongTensor(training_label[j * 240:j * 240 + 240]).cuda()
                optim_b.zero_grad()

                cls_score = self.ROI(data_)
                cross_entropy2 = self.tool2.cal_loss2(cls_score, label_)

                loss_total = cross_entropy2
                loss_total.backward()
                optim_b.step()
                runing_losss += loss_total.item()
                cls_loss2 += cross_entropy2.item()
                # NOTE(review): cross_entropy / loss_box here are stale values
                # left over from the last harvesting batch above -- each
                # iteration adds the same constant to the logged totals.
                cls_loss += cross_entropy.item()
                coor_loss += loss_box.item()
            end_time = time.time()
            torch.cuda.empty_cache()
            print(
                "epoch:{a} time:{ff}: loss:{b:.4f} cls:{d:.4f} cor{e:.4f} cls2:{f:.4f} cor2:{g:.4f} date:{fff}".format(
                    a=epoch,
                    b=runing_losss,
                    d=cls_loss,
                    e=coor_loss,
                    f=cls_loss2,
                    g=coor_loss2, ff=int(end_time - start_time),
                    fff=time.asctime()))
            if epoch % 10 == 0 and epoch > 0:
                adjust_learning_rate(optim_b, 0.9, epoch, 50, 0.3)

            p = None
            # NOTE: evaluation flips the ROI head to eval() and never flips it
            # back; the next training epoch therefore runs with eval() BN/
            # dropout behaviour -- confirm whether this is intended.
            self.eval_(test_data, test_label)
            # self.ROI_eval(data_loader_eval, {"epoch": epoch})

            start_time = end_time
        print('finish')

    def eval_(self, data, label):
        """Score the ROI head on held-out features in fixed chunks of 240 and
        print micro-averaged PPV / specificity / sensitivity.

        :param data: feature array, reshaped per chunk to (240, 1024, 15).
        :param label: ground-truth labels aligned with ``data``.
        """
        self.ROI = self.ROI.eval()
        gt = []
        pre = []
        # Trailing samples beyond a multiple of 240 are silently dropped.
        total = int(len(label) / 240)
        with torch.no_grad():
            for i in range(total):
                a = i * 240
                b = a + 240
                sin_x = torch.Tensor(data[a:b]).cuda()
                sin_x = sin_x.view(240, 1024, 15)
                sin_y = label[a:b]
                predict = self.ROI(sin_x)
                predict, index = torch.max(predict, 1)
                pre.extend(index.cpu().tolist())
                gt.extend(sin_y)
        print("ppv:{}".format(metrics.precision_score(gt, pre, average='micro')))
        print("spe:{}".format(specificity_score(gt, pre, average='micro')))
        print("sen:{}".format(metrics.recall_score(gt, pre, average='micro')))


    def base_process_2(self, x, y, peak):
        """Run the frozen backbone + RPN and return pooled ROI features.

        :param x: input batch (moved to GPU by the caller).
        :param y: ground-truth targets used for proposal/GT matching.
        :param peak: per-sample peak info forwarded to the matching helper.
        :return: dict with the fused features ``x``, matched ``label``,
            ``class_num`` and placeholder ``cross_entropy`` / ``loss_box``
            (both fixed to ``torch.ones(1)`` here).
        """
        cross_entropy, loss_box = torch.ones(1), torch.ones(1)
        with torch.no_grad():
            x1, x2, x3, x4 = self.features(x)
            # NOTE(review): `proposal` is only assigned inside this branch --
            # any flag value other than 3 raises NameError at the matching
            # step below (flag==2 still reaches it).  Confirm allowed flags.
            if self.flag == 3:
                predict_confidence, box_predict = self.RPN(x1, x2, x3, x4)
                proposal, batch_offset, batch_conf = self.tool.get_proposal(predict_confidence, box_predict,
                                                                            y, test=True)
                # save_proposal = [i.cpu().numpy() for i in proposal]
                # save_data = x.cpu().numpy()
                # save_y = [i.numpy() for i in y]
                # self.save_dict['data'].append(save_data)
                # self.save_dict['label'].append(save_y)
                # self.save_dict['predict'].append(save_proposal)

            proposal, label = self.tool2.pre_gt_match_uniform(proposal, y, training=True, params={'peak': peak})

            # Prefix each proposal row with its batch index, then flatten the
            # per-image lists into one (N, 1+coords) tensor for ROI pooling.
            if 1:
                for i in range(len(proposal)):
                    tmp = torch.zeros(proposal[i].size()[0], 1).fill_(
                        i).cuda()
                    proposal[i] = torch.cat([tmp, proposal[i]], 1)
                proposal = torch.cat(proposal, 0)

            # Pool every feature level over the same proposals; only the first
            # call also returns labels / class counts.
            feat4, label, class_num = self.tool2.roi_pooling_cuda(x4, proposal, label=label, stride=64,
                                                                  pool=self.pool4,
                                                                  batch=True)
            # NOTE(review): stride=64 for x3 looks inconsistent with the
            # x4=64 / x2=32 / x1=16 progression (32 expected) -- confirm.
            feat3 = \
                self.tool2.roi_pooling_cuda(x3, proposal, stride=64, pool=self.pool3,
                                            batch=True, label=None)[
                    0]
            feat2 = \
                self.tool2.roi_pooling_cuda(x2, proposal, stride=32,
                                            pool=self.pool2,
                                            batch=True, label=None)[0]
            feat1 = \
                self.tool2.roi_pooling_cuda(x1, proposal, stride=16,
                                            pool=self.pool1,
                                            batch=True, label=None, )[0]

            # Fuse the pooled levels and flatten to (N, 1024*15) vectors.
            x = self.pre(feat1, feat2, feat3, feat4)
            x = x.view(-1, 1024 * 15)
            if self.flag == 2:
                result = {}
                result['x'] = x
                result['label'] = label
                result['predict_offset'] = 0
                result['class_num'] = class_num
                result['batch_cor_weight'] = 0
                result['cross_entropy'] = cross_entropy
                result['loss_box'] = loss_box
                return result
            elif self.flag == 3:
                result = {}
                result['x'] = x
                result['label'] = label
                result['class_num'] = class_num
                result['cross_entropy'] = cross_entropy
                result['loss_box'] = loss_box

                return result
class ImagenetExperiment:
    """
    Experiment class used to train Sparse and dense versions of Resnet50 v1.5
    models on Imagenet dataset
    """
    def __init__(self):
        # All attributes are populated for real in `setup_experiment`;
        # the values below are safe defaults so the instance can be
        # inspected (or checkpoint-restored) before configuration.
        self.model = None
        self.optimizer = None
        self.loss_function = None
        self.lr_scheduler = None
        self.train_loader = None
        self.val_loader = None
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.batches_in_epoch = sys.maxsize
        self.batch_size = 1
        self.epochs = 1
        self.distributed = False
        self.mixed_precision = False
        self.rank = 0
        self.total_batches = 0
        self.progress = False
        self.logger = None
        self.seed = 42
        self.profile = False
        self.launch_time = 0
        self.epochs_to_validate = []

    def setup_experiment(self, config):
        """
        Configure the experiment for training

        :param config: Dictionary containing the configuration parameters

            - distributed: Whether or not to use Pytorch Distributed training
            - backend: Pytorch Distributed backend ("nccl", "gloo")
                    Default: nccl
            - world_size: Total number of processes participating
            - rank: Rank of the current process
            - data: Dataset path
            - train_dir: Dataset training data relative path
            - batch_size: Training batch size
            - val_dir: Dataset validation data relative path
            - val_batch_size: Validation batch size
            - workers: how many data loading processes to use
            - num_classes: Limit the dataset size to the given number of classes
            - model_class: Model class. Must inherit from "torch.nn.Module"
            - model_args: model model class arguments passed to the constructor
            - init_batch_norm: Whether or not to Initialize running batch norm
                               mean to 0.
            - optimizer_class: Optimizer class.
                               Must inherit from "torch.optim.Optimizer"
            - optimizer_args: Optimizer class class arguments passed to the
                              constructor
            - batch_norm_weight_decay: Whether or not to apply weight decay to
                                       batch norm modules parameters
            - lr_scheduler_class: Learning rate scheduler class.
                                 Must inherit from "_LRScheduler"
            - lr_scheduler_args: Learning rate scheduler class class arguments
                                 passed to the constructor
            - loss_function: Loss function. See "torch.nn.functional"
            - local_dir: Results path
            - epochs: Number of epochs to train
            - batches_in_epoch: Number of batches per epoch.
                                Useful for debugging
            - progress: Show progress during training
            - profile: Whether or not to enable torch.autograd.profiler.profile
                       during training
            - name: Experiment name. Used as logger name
            - log_level: Python Logging level
            - log_format: Python Logging format
            - seed: the seed to be used for pytorch, python, and numpy
            - mixed_precision: Whether or not to enable apex mixed precision
            - mixed_precision_args: apex mixed precision arguments.
                                    See "amp.initialize"
            - create_train_dataloader: Optional user defined function to create
                                       the training data loader. See below for
                                       input params.
            - create_validation_dataloader: Optional user defined function to create
                                            the validation data loader. See below for
                                            input params.
            - train_model_func: Optional user defined function to train the model,
                                expected to behave similarly to `train_model`
                                in terms of input parameters and return values
            - evaluate_model_func: Optional user defined function to validate the model
                                   expected to behave similarly to `evaluate_model`
                                   in terms of input parameters and return values
            - init_hooks: list of hooks (functions) to call on the model
                          just following its initialization
            - post_epoch_hooks: list of hooks (functions) to call on the model
                                following each epoch of training
            - checkpoint_file: if not None, will start from this model. The model
                               must have the same model_args and model_class as the
                               current experiment.
            - checkpoint_at_init: boolean argument for whether to create a checkpoint
                                  of the initialized model. this differs from
                                  `checkpoint_at_start` for which the checkpoint occurs
                                  after the first epoch of training as opposed to
                                  before it
            - epochs_to_validate: list of epochs to run validate(). A -1 asks
                                  to run validate before any training occurs.
                                  Default: last three epochs.
            - launch_time: time the config was created (via time.time). Used to report
                           wall clock time until the first batch is done.
                           Default: time.time() in this setup_experiment().
        """
        # Configure logging related stuff
        log_format = config.get("log_format", logging.BASIC_FORMAT)
        log_level = getattr(logging, config.get("log_level", "INFO").upper())
        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter(log_format))
        self.logger = logging.getLogger(config.get("name",
                                                   type(self).__name__))
        self.logger.setLevel(log_level)
        self.logger.addHandler(console)
        self.progress = config.get("progress", False)
        self.launch_time = config.get("launch_time", time.time())

        # Configure seed
        self.seed = config.get("seed", self.seed)
        set_random_seed(self.seed, False)

        # Configure distribute pytorch
        self.distributed = config.get("distributed", False)
        self.rank = config.get("rank", 0)
        if self.distributed:
            dist_url = config.get("dist_url", "tcp://127.0.0.1:54321")
            backend = config.get("backend", "nccl")
            world_size = config.get("world_size", 1)
            dist.init_process_group(
                backend=backend,
                init_method=dist_url,
                rank=self.rank,
                world_size=world_size,
            )
            # Only enable logs from first process
            self.logger.disabled = self.rank != 0
            self.progress = self.progress and self.rank == 0

        # Configure model
        model_class = config["model_class"]
        model_args = config.get("model_args", {})
        init_batch_norm = config.get("init_batch_norm", False)
        init_hooks = config.get("init_hooks", None)
        self.model = create_model(model_class=model_class,
                                  model_args=model_args,
                                  init_batch_norm=init_batch_norm,
                                  device=self.device,
                                  init_hooks=init_hooks,
                                  checkpoint_file=config.get(
                                      "checkpoint_file", None))
        if self.rank == 0:
            self.logger.debug(self.model)
            params_sparse, nonzero_params_sparse2 = count_nonzero_params(
                self.model)
            self.logger.debug("Params total/nnz %s / %s = %s ", params_sparse,
                              nonzero_params_sparse2,
                              float(nonzero_params_sparse2) / params_sparse)

        # Configure optimizer
        optimizer_class = config.get("optimizer_class", torch.optim.SGD)
        optimizer_args = config.get("optimizer_args", {})
        batch_norm_weight_decay = config.get("batch_norm_weight_decay", True)
        self.optimizer = create_optimizer(
            model=self.model,
            optimizer_class=optimizer_class,
            optimizer_args=optimizer_args,
            batch_norm_weight_decay=batch_norm_weight_decay,
        )

        # Validate mixed precision requirements
        self.mixed_precision = config.get("mixed_precision", False)
        if self.mixed_precision and amp is None:
            self.mixed_precision = False
            # Fixed message: the original concatenated the three literals
            # without separators and misspelled "NVIDIA".
            self.logger.error(
                "Mixed precision requires NVIDIA APEX. "
                "Please install apex from https://www.github.com/nvidia/apex. "
                "Disabling mixed precision training.")

        # Configure mixed precision training
        if self.mixed_precision:
            amp_args = config.get("mixed_precision_args", {})
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, **amp_args)
            self.logger.info("Using mixed precision")

        # Apply DistributedDataParallel after all other model mutations
        if self.distributed:
            self.model = DistributedDataParallel(self.model)
        else:
            self.model = DataParallel(self.model)

        self.loss_function = config.get("loss_function",
                                        torch.nn.functional.cross_entropy)

        # Configure data loaders
        self.epochs = config.get("epochs", 1)
        self.batches_in_epoch = config.get("batches_in_epoch", sys.maxsize)
        self.epochs_to_validate = config.get(
            "epochs_to_validate", range(self.epochs - 3, self.epochs + 1))
        workers = config.get("workers", 0)
        data_dir = config["data"]
        train_dir = config.get("train_dir", "train")
        num_classes = config.get("num_classes", 1000)

        # Get initial batch size
        self.batch_size = config.get("batch_size", 1)

        # CUDA runtime does not support the fork start method.
        # See https://pytorch.org/docs/stable/notes/multiprocessing.html
        # NOTE(review): set_start_method raises RuntimeError if the start
        # method was already set by the launcher — confirm callers only run
        # setup_experiment once per process.
        if torch.cuda.is_available():
            multiprocessing.set_start_method("spawn")

        # Configure Training data loader
        self.create_train_dataloader = config.get("create_train_dataloader",
                                                  create_train_dataloader)
        self.train_loader = self.create_train_dataloader(
            data_dir=data_dir,
            train_dir=train_dir,
            batch_size=self.batch_size,
            workers=workers,
            distributed=self.distributed,
            num_classes=num_classes,
            use_auto_augment=config.get("use_auto_augment", False),
        )
        self.total_batches = len(self.train_loader)

        # Configure Validation data loader
        val_dir = config.get("val_dir", "val")
        val_batch_size = config.get("val_batch_size", self.batch_size)
        self.create_validation_dataloader = config.get(
            "create_validation_dataloader", create_validation_dataloader)
        self.val_loader = self.create_validation_dataloader(
            data_dir=data_dir,
            val_dir=val_dir,
            batch_size=val_batch_size,
            workers=workers,
            num_classes=num_classes,
        )

        # Configure learning rate scheduler
        lr_scheduler_class = config.get("lr_scheduler_class", None)
        if lr_scheduler_class is not None:
            lr_scheduler_args = config.get("lr_scheduler_args", {})
            self.logger.info("LR Scheduler args:")
            self.logger.info(pformat(lr_scheduler_args))
            self.logger.info("steps_per_epoch=%s", self.total_batches)
            self.lr_scheduler = create_lr_scheduler(
                optimizer=self.optimizer,
                lr_scheduler_class=lr_scheduler_class,
                lr_scheduler_args=lr_scheduler_args,
                steps_per_epoch=self.total_batches)

        # Only profile from rank 0
        self.profile = config.get("profile", False) and self.rank == 0

        # Set train and validate methods.
        self.train_model = config.get("train_model_func", train_model)
        self.evaluate_model = config.get("evaluate_model_func", evaluate_model)

        # Register post-epoch hooks. To be used as `self.model.apply(post_epoch_hook)`
        self.post_epoch_hooks = config.get("post_epoch_hooks", [])

    def validate(self, epoch, loader=None):
        """Evaluate the model if ``epoch`` is scheduled for validation.

        :param epoch: current epoch number (-1 means "before any training")
        :param loader: data loader to evaluate on. Defaults to the
                       validation loader created in ``setup_experiment``
        :return: dict with evaluation results plus the current
                 "learning_rate"; zeroed metrics when the epoch is skipped
        """
        if loader is None:
            loader = self.val_loader

        if epoch in self.epochs_to_validate:
            results = self.evaluate_model(
                model=self.model,
                loader=loader,
                device=self.device,
                criterion=self.loss_function,
                batches_in_epoch=self.batches_in_epoch,
            )
        else:
            results = {
                "total_correct": 0,
                "mean_loss": 0.0,
                "mean_accuracy": 0.0,
            }

        results.update(learning_rate=self.get_lr()[0])
        self.logger.info(results)

        return results

    def train_epoch(self, epoch):
        """Train the model for one epoch, optionally under the profiler."""
        with torch.autograd.profiler.profile(
                use_cuda=torch.cuda.is_available(),
                enabled=self.profile) as prof:
            self.train_model(
                model=self.model,
                loader=self.train_loader,
                optimizer=self.optimizer,
                device=self.device,
                criterion=self.loss_function,
                batches_in_epoch=self.batches_in_epoch,
                pre_batch_callback=functools.partial(self.pre_batch,
                                                     epoch=epoch),
                post_batch_callback=functools.partial(self.post_batch,
                                                      epoch=epoch),
            )
        if self.profile and prof is not None:
            self.logger.info(
                prof.key_averages().table(sort_by="self_cpu_time_total"))

    def run_epoch(self, epoch):
        """Run pre-epoch hooks, training, post-epoch hooks and validation.

        :return: validation results dict for this epoch
        """
        if -1 in self.epochs_to_validate and epoch == 0:
            self.logger.debug("Validating before any training:")
            self.validate(epoch=-1)
        self.pre_epoch(epoch)
        self.train_epoch(epoch)
        self.post_epoch(epoch)
        t1 = time.time()
        ret = self.validate(epoch)

        if self.rank == 0:
            self.logger.debug("validate time: %s", time.time() - t1)
            self.logger.debug("---------- End of run epoch ------------")
            self.logger.debug("")

        return ret

    def pre_epoch(self, epoch):
        """Hook called before each training epoch."""
        self.model.apply(update_boost_strength)
        if self.distributed:
            # Reshuffle the distributed sampler deterministically per epoch
            self.train_loader.sampler.set_epoch(epoch)

    def pre_batch(self, model, batch_idx, epoch):
        """Hook called before each training batch. No-op by default."""
        pass

    def post_batch(self, model, loss, batch_idx, epoch, num_images,
                   time_string):
        """Hook called after each training batch (LR step + progress logs)."""
        # Update 1cycle learning rate after every batch
        if isinstance(self.lr_scheduler, (OneCycleLR, ComposedLRScheduler)):
            self.lr_scheduler.step()

        if self.progress and epoch == 0 and batch_idx == 0:
            self.logger.info("Launch time to end of first batch: %s",
                             time.time() - self.launch_time)

        if self.progress and (batch_idx % 40) == 0:
            total_batches = self.total_batches
            current_batch = batch_idx
            if self.distributed:
                # Compute actual batch size from distributed sampler
                total_batches *= self.train_loader.sampler.num_replicas
                current_batch *= self.train_loader.sampler.num_replicas
            self.logger.debug(
                "End of batch for rank: %s. Epoch: %s, Batch: %s/%s, "
                "loss: %s, Learning rate: %s num_images: %s", self.rank, epoch,
                current_batch, total_batches, loss, self.get_lr(), num_images)
            self.logger.debug("Timing: %s", time_string)

    def post_epoch(self, epoch):
        """Hook called after each training epoch (rezero, hooks, LR step)."""
        count_nnz = self.logger.isEnabledFor(logging.DEBUG) and self.rank == 0
        if count_nnz:
            params_sparse, nonzero_params_sparse1 = count_nonzero_params(
                self.model)

        self.model.apply(rezero_weights)
        if self.post_epoch_hooks:
            for hook in self.post_epoch_hooks:
                self.model.apply(hook)

        if count_nnz:
            params_sparse, nonzero_params_sparse2 = count_nonzero_params(
                self.model)
            self.logger.debug(
                "Params total/nnz before/nnz after %s %s / %s = %s",
                params_sparse, nonzero_params_sparse1, nonzero_params_sparse2,
                float(nonzero_params_sparse2) / params_sparse)

        self.logger.debug("End of epoch %s LR/weight decay before step: %s/%s",
                          epoch, self.get_lr(), self.get_weight_decay())

        # Update learning rate. Guard against the no-scheduler configuration:
        # lr_scheduler stays None when no lr_scheduler_class is given, and the
        # original unconditional step() raised AttributeError every epoch.
        # Per-batch schedulers (OneCycleLR/Composed) already stepped in
        # post_batch, so they are skipped here.
        if self.lr_scheduler is not None and not isinstance(
                self.lr_scheduler, (OneCycleLR, ComposedLRScheduler)):
            self.lr_scheduler.step()

        self.logger.debug("End of epoch %s LR/weight decay after step: %s/%s",
                          epoch, self.get_lr(), self.get_weight_decay())

    def get_state(self):
        """
        Get experiment serialized state as a dictionary of  byte arrays
        :return: dictionary with "model", "optimizer" and "lr_scheduler" states
        """
        # Save state into a byte array to avoid ray's GPU serialization issues
        # See https://github.com/ray-project/ray/issues/5519
        state = {}
        with io.BytesIO() as buffer:
            serialize_state_dict(buffer, self.model.module.state_dict())
            state["model"] = buffer.getvalue()

        with io.BytesIO() as buffer:
            serialize_state_dict(buffer, self.optimizer.state_dict())
            state["optimizer"] = buffer.getvalue()

        # lr_scheduler is optional; only serialize it when configured
        if self.lr_scheduler is not None:
            with io.BytesIO() as buffer:
                serialize_state_dict(buffer, self.lr_scheduler.state_dict())
                state["lr_scheduler"] = buffer.getvalue()

        if self.mixed_precision:
            with io.BytesIO() as buffer:
                serialize_state_dict(buffer, amp.state_dict())
                state["amp"] = buffer.getvalue()

        return state

    def set_state(self, state):
        """
        Restore the experiment from the state returned by `get_state`
        :param state: dictionary with "model", "optimizer", "lr_scheduler", and "amp"
                      states
        """
        if "model" in state:
            with io.BytesIO(state["model"]) as buffer:
                state_dict = deserialize_state_dict(buffer, self.device)
            self.model.module.load_state_dict(state_dict)

        if "optimizer" in state:
            with io.BytesIO(state["optimizer"]) as buffer:
                state_dict = deserialize_state_dict(buffer, self.device)
            self.optimizer.load_state_dict(state_dict)

        # Skip scheduler state when this experiment has no scheduler
        if "lr_scheduler" in state and self.lr_scheduler is not None:
            with io.BytesIO(state["lr_scheduler"]) as buffer:
                state_dict = deserialize_state_dict(buffer, self.device)
            self.lr_scheduler.load_state_dict(state_dict)

        if "amp" in state and amp is not None:
            with io.BytesIO(state["amp"]) as buffer:
                state_dict = deserialize_state_dict(buffer, self.device)
            amp.load_state_dict(state_dict)

    def stop_experiment(self):
        """Tear down the distributed process group, if one was created."""
        if self.distributed:
            dist.destroy_process_group()

    def get_lr(self):
        """
        Returns the current learning rate
        :return: list of learning rates used by the optimizer
        """
        return [p["lr"] for p in self.optimizer.param_groups]

    def get_weight_decay(self):
        """
        Returns the current weight decay
        :return: list of weight decays used by the optimizer
        """
        return [p["weight_decay"] for p in self.optimizer.param_groups]

    def get_node_ip(self):
        """Returns the IP address of the current ray node."""
        return ray.services.get_node_ip_address()
# Example 10
class model():
    """Training/evaluation harness for the 10s SMOTE multi-head classifier."""

    def __init__(self):
        # Fixed typo: this was originally spelled ``__int__`` so the
        # constructor never ran. It intentionally does nothing; all state is
        # created lazily in ``train``/``test``.
        pass

    def train(self):
        """Train the classifier indefinitely, evaluating after every epoch."""
        os.environ["CUDA_VISIBLE_DEVICES"] = '2'
        device_ids = [0]
        self.classifier = classifier()
        get_paprams(self.classifier)
        get_paprams(self.classifier.base)
        data_set = my_dataset_10s_smote()
        # Test split reuses the training split's data/labels and index so the
        # two stay consistent with each other.
        data_set_test = my_dataset_10s_smote(test=True,
                                             all_data=data_set.all_data,
                                             all_label=data_set.all_label,
                                             index_=data_set.index)
        batch = 300
        # Restored: this constant was commented out upstream but is still
        # required by the learning-rate schedule below (previously a
        # NameError on every 10th epoch).
        total_epoch = 2000
        data_loader = DataLoader(data_set,
                                 batch,
                                 shuffle=True,
                                 collate_fn=detection_collate)
        data_loader_test = DataLoader(data_set_test,
                                      batch,
                                      False,
                                      collate_fn=detection_collate)
        self.classifier = self.classifier.cuda()
        self.classifier = DataParallel(self.classifier, device_ids=device_ids)
        optim = Adadelta(self.classifier.parameters(),
                         0.1,
                         0.9,
                         weight_decay=1e-5)

        self.cretion = smooth_focal_weight()

        self.classifier.apply(weights_init)
        start_time = time.time()
        count = 0
        epoch = -1
        while 1:
            epoch += 1
            runing_losss = [0] * 5
            for data in data_loader:
                loss = [0] * 5
                y = data[1].cuda()
                x = data[0].cuda()
                optim.zero_grad()

                # Per-class weights for the focal-style criterion
                weight = torch.Tensor([0.5, 2, 0.5, 2]).cuda()

                predict = self.classifier(x)

                # One loss term per prediction head
                for i in range(5):
                    loss[i] = self.cretion(predict[i], y, weight)
                tmp = sum(loss)

                tmp.backward()
                optim.step()
                # Accumulate each head's loss separately. The original added
                # the *summed* loss ``tmp`` into every slot, which made all
                # five running totals identical.
                for i in range(5):
                    runing_losss[i] += loss[i].item()

                count += 1
            end_time = time.time()
            print("epoch:{a}: loss:{b} spend_time:{c} time:{d}".format(
                a=epoch,
                b=sum(runing_losss),
                c=int(end_time - start_time),
                d=time.asctime()))
            start_time = end_time

            # Save the backbone and the full model to distinct files; they
            # previously shared one filename, so the second save clobbered
            # the backbone checkpoint.
            save(self.classifier.module.base.state_dict(),
                 str(epoch) + 'base_c1.p')
            save(self.classifier.module.state_dict(),
                 str(epoch) + 'classifier_c1.p')

            self.classifier.eval()

            self.evaluation(self.classifier, data_loader_test, epoch)

            self.classifier.train()
            if epoch % 10 == 0:
                adjust_learning_rate(optim, 0.9, epoch, total_epoch, 0.1)

    def evaluation(self, classifier, data_loader_test, epoch):
        """Print accuracy/precision/recall/F1 for each of the 5 heads.

        :param classifier: model to evaluate (caller sets eval/train mode)
        :param data_loader_test: loader yielding (x, y, ...) batches
        :param epoch: current epoch number (unused, kept for the callers)
        """
        all_predict = [[], [], [], [], []]
        all_ground = []
        with torch.no_grad():
            for data in data_loader_test:
                y = data[1].cuda()
                x = data[0].cuda()
                predict_list = classifier(x)
                # Collect argmax predictions per head
                for i in range(5):
                    predict, index = torch.max(predict_list[i], 1)
                    all_predict[i].extend(index.tolist())

                all_ground.extend(y.tolist())

        # Head 0 summary, including the confusion matrix
        print("Accuracy:{}".format(
            metrics.accuracy_score(all_ground, all_predict[0])))
        print('precesion:{}'.format(
            metrics.precision_score(all_ground, all_predict[0], average=None)))
        print('recall:{}'.format(
            metrics.recall_score(all_ground, all_predict[0], average=None)))
        print('f-score:{}'.format(
            metrics.f1_score(all_ground, all_predict[0], average=None)))
        print("{}".format(metrics.confusion_matrix(all_ground,
                                                   all_predict[0])))
        # Per-head metrics for all five heads
        for i in range(5):
            tmp = metrics.accuracy_score(all_ground, all_predict[i])
            tmp2 = metrics.precision_score(all_ground,
                                           all_predict[i],
                                           average=None)
            tmp3 = metrics.recall_score(all_ground,
                                        all_predict[i],
                                        average=None)
            tmp4 = metrics.f1_score(all_ground, all_predict[i], average=None)
            print("Accuracy:{}".format(tmp))
            print('precesion:{}'.format(tmp2))
            print('recall:{}'.format(tmp3))
            print('f-score:{}'.format(tmp4))

    def test(self):
        """Evaluate a (freshly constructed) classifier on the test split."""
        # NOTE(review): unlike ``train``, no checkpoint is loaded here, so
        # this evaluates randomly-initialized weights — confirm intent.
        self.classifier = classifier()
        self.classifier = self.classifier.cuda()
        data_set = my_dataset_10s_smote(test=True)
        data_loader_test = DataLoader(data_set,
                                      300,
                                      False,
                                      collate_fn=detection_collate)
        all_predict = []
        all_ground = []
        self.classifier.eval()
        self.classifier.base.eval()
        total = 0
        with torch.no_grad():
            for data in data_loader_test:
                y = data[1].cuda()
                x = data[0].cuda()
                every_len = data[2]
                max_len = data[3]
                predict = self.classifier(x, every_len, max_len)[0]
                predict, index = torch.max(predict, 1)
                all_predict.extend(list(index.cpu().numpy()))
                all_ground.extend(list(y.cpu().numpy()))
        print(metrics.precision_score(all_ground, all_predict, average=None))
        print(metrics.recall_score(all_ground, all_predict, average=None))
        print(metrics.f1_score(all_ground, all_predict, average=None))
        print(metrics.confusion_matrix(all_ground, all_predict))
# Example 11
class GAUGAN(object):
    """SPADE/GauGAN-style segmentation-to-image translation trainer.

    Bundles an image encoder (producing the style latent), a
    segmentation-conditioned generator and two multi-scale discriminators,
    and runs the adversarial training loop with feature-matching, KL and
    optional VGG perceptual losses.
    """

    def __init__(self, args):
        # All hyper-parameters come straight off the parsed-argument
        # namespace; nothing else in the class touches `args` directly.
        self.device = args.device
        self.img_path = args.img_path
        self.seg_path = args.seg_path
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.lr_G = args.lr_G
        self.lr_D = args.lr_D
        self.beta_1 = args.beta_1
        self.beta_2 = args.beta_2
        self.total_step = args.total_step
        self.n_critic = args.n_critic
        self.n_save = args.n_save
        self.ckpt_dir = args.ckpt_dir
        self.lambda_fm = args.lambda_fm
        self.lambda_kl = args.lambda_kl
        self.lambda_vgg = args.lambda_vgg
        self.grid_n_row = args.grid_n_row
        self.n_save_image = args.n_save_image
        self.img_dir = args.img_dir
        self.GAN_D_loss_type = args.GAN_D_loss_type
        self.GAN_G_loss_type = args.GAN_G_loss_type
        self.save_Dis = args.save_Dis
        self.start_annealing_epoch = args.start_annealing_epoch
        self.end_annealing_epoch = args.end_annealing_epoch
        self.end_lr = args.end_lr
        self.test_img_path = args.test_img_path
        self.test_seg_path = args.test_seg_path
        self.test_batch_size = args.test_batch_size
        self.use_vgg = args.use_vgg
        self.n_summary = args.n_summary
        self.sum_dir = args.sum_dir
        self.seg_channel = args.seg_channel

    def load_dataset(self):
        """Build train/test datasets and loaders of (image, segmentation) pairs."""
        # interpolation=0 is PIL NEAREST; the same transform is used for
        # both photos and label maps (nearest keeps labels intact).
        self.transform_img = transforms.Compose([
            transforms.Resize(size=256, interpolation=0),
            transforms.ToTensor(),
        ])

        self.dataset = customDataset(
            origin_path=self.img_path,
            segmen_path=self.seg_path,
            transform=self.transform_img,
        )

        # drop_last keeps every batch at exactly batch_size, which the
        # fixed-size latent sampling in eval() also relies on.
        self.loader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

        self.test_dataset = customDataset(
            origin_path=self.test_img_path,
            segmen_path=self.test_seg_path,
            transform=self.transform_img,
        )

        self.test_loader = DataLoader(dataset=self.test_dataset,
                                      batch_size=self.test_batch_size)

    def build_model(self):
        """Instantiate networks (DataParallel-wrapped), losses and optimizers."""
        ########## Networks ##########
        self.enc = DP(image_encoder()).to(self.device)
        self.gen = DP(generator(seg_channel=self.seg_channel)).to(self.device)
        # Two discriminators at full and half resolution (multi-scale).
        self.disORI = DP(discriminator(down_scale=1)).to(self.device)
        self.disHAL = DP(discriminator(down_scale=2)).to(self.device)
        # self.disQUA = DP(discriminator(down_scale=4)).to(self.device)

        ########## Init Networks with Xavier normal ##########
        self.enc.apply(networks._init_weights)
        self.gen.apply(networks._init_weights)
        self.disORI.apply(networks._init_weights)
        self.disHAL.apply(networks._init_weights)
        # self.disQUA.apply(networks._init_weights)

        ########## Loss ##########
        self.KLloss = KL_loss(self.device)
        self.GAN_D_loss = GAN_D_loss(self.device,
                                     GAN_D_loss_type=self.GAN_D_loss_type)
        self.GAN_G_loss = GAN_G_loss(self.device,
                                     GAN_G_loss_type=self.GAN_G_loss_type)
        self.FMloss = FM_loss(self.device)
        if self.use_vgg:
            self.VGGloss = VGG_loss(self.device)

        ########## Optimizer ##########
        # Generator optimizer also updates the encoder (they train jointly).
        self.G_optim = torch.optim.Adam(list(self.gen.parameters()) +
                                        list(self.enc.parameters()),
                                        lr=self.lr_G,
                                        betas=(self.beta_1, self.beta_2))
        # NOTE(review): this annealing lambda looks off (it adds self.lr_G
        # and divides by start-end); confirm intended schedule before reuse.
        self.G_lambda = lambda epoch: max(self.end_lr, (
            epoch - self.start_annealing_epoch) * (self.lr_G - self.end_lr) /
                                          (self.start_annealing_epoch - self.
                                           end_annealing_epoch) + self.lr_G)
        self.G_optim_sch = torch.optim.lr_scheduler.LambdaLR(
            self.G_optim, lr_lambda=self.G_lambda)

        self.D_optim = torch.optim.Adam(
            list(self.disORI.parameters()) + list(self.disHAL.parameters())
            # + list(self.disQUA.parameters())
            ,
            lr=self.lr_D,
            betas=(self.beta_1, self.beta_2))
        self.D_lambda = lambda epoch: max(self.end_lr, (
            epoch - self.start_annealing_epoch) * (self.lr_D - self.end_lr) /
                                          (self.start_annealing_epoch - self.
                                           end_annealing_epoch) + self.lr_D)
        self.D_optim_sch = torch.optim.lr_scheduler.LambdaLR(
            self.D_optim, lr_lambda=self.D_lambda)

    def train(self):
        """Main adversarial training loop over self.total_step iterations."""
        self.enc.train()
        self.gen.train()
        self.disORI.train()
        self.disHAL.train()
        # self.disQUA.train()

        summary = SummaryWriter(self.sum_dir)

        data_loader = iter(self.loader)

        # One fixed test batch, reused for all sample-image dumps.
        test_data_loader = iter(self.test_loader)
        test_real_image, test_seg = next(test_data_loader)
        test_real_image, test_seg = test_real_image.to(
            self.device), test_seg.to(self.device)

        pbar = tqdm(range(self.total_step))
        epoch = 0

        for step in pbar:
            try:
                real_image, seg = next(data_loader)

            except StopIteration:
                # Loader exhausted: restart it, count an epoch, and step the
                # LR schedulers once annealing has begun.  (Narrowed from a
                # bare `except:` so real errors are no longer swallowed.)
                data_loader = iter(self.loader)
                real_image, seg = next(data_loader)
                epoch += 1
                if epoch >= self.start_annealing_epoch:
                    self.G_optim_sch.step()
                    self.D_optim_sch.step()

            real_image, seg = real_image.to(self.device), seg.to(self.device)

            ##### train Discriminator #####
            self.D_optim.zero_grad()

            mu, squ_sigma, z = self.enc(real_image)
            fake_image = self.gen(z, seg)

            list_real_ORI, list_fake_ORI = self.disORI(real_image, fake_image,
                                                       seg)
            list_real_HAL, list_fake_HAL = self.disHAL(real_image, fake_image,
                                                       seg)
            # list_real_QUA, list_fake_QUA = self.disQUA(real_image, fake_image, seg)

            # Last element of each discriminator output is the logit;
            # earlier elements are intermediate feature maps (used for FM loss).
            listed_D_real = [list_real_ORI[-1], list_real_HAL[-1]]
            listed_D_fake = [list_fake_ORI[-1], list_fake_HAL[-1]]
            # listed_D_real = [list_real_ORI[-1], list_real_HAL[-1], list_real_QUA[-1]]
            # listed_D_fake = [list_fake_ORI[-1], list_fake_HAL[-1], list_fake_QUA[-1]]

            GAN_D_real_loss, GAN_D_fake_loss, GAN_D_loss = self.GAN_D_loss(
                listed_D_real, listed_D_fake)

            GAN_D = GAN_D_loss
            GAN_D.backward()
            self.D_optim.step()

            ##### train Generator #####
            # Generator/encoder updated only every n_critic discriminator steps.
            if step % self.n_critic == 0:
                self.G_optim.zero_grad()

                mu, squ_sigma, z = self.enc(real_image)
                fake_image = self.gen(z, seg)

                list_real_ORI, list_fake_ORI = self.disORI(
                    real_image, fake_image, seg)
                list_real_HAL, list_fake_HAL = self.disHAL(
                    real_image, fake_image, seg)
                # list_real_QUA, list_fake_QUA = self.disQUA(real_image, fake_image, seg)

                listed_FMs_D_real = [list_real_ORI[:-1], list_real_HAL[:-1]]
                listed_FMs_D_fake = [list_fake_ORI[:-1], list_fake_HAL[:-1]]
                # listed_FMs_D_real = [list_real_ORI[:-1], list_real_HAL[:-1], list_real_QUA[:-1]]
                # listed_FMs_D_fake = [list_fake_ORI[:-1], list_fake_HAL[:-1], list_fake_QUA[:-1]]

                listed_D_real = [list_real_ORI[-1], list_real_HAL[-1]]
                listed_D_fake = [list_fake_ORI[-1], list_fake_HAL[-1]]
                # listed_D_real = [list_real_ORI[-1], list_real_HAL[-1], list_real_QUA[-1]]
                # listed_D_fake = [list_fake_ORI[-1], list_fake_HAL[-1], list_fake_QUA[-1]]

                GAN_G_loss = self.GAN_G_loss(listed_D_fake)
                fm_loss = self.FMloss(listed_FMs_D_real, listed_FMs_D_fake)
                kl_loss = self.KLloss(mu, squ_sigma)
                if self.use_vgg:
                    vgg_loss = self.VGGloss(real_image, fake_image)
                else:
                    # Plain scalar when VGG loss is disabled; writeLogs and the
                    # status line below both tolerate this.
                    vgg_loss = 0

                GAN_G = GAN_G_loss + self.lambda_fm * fm_loss + self.lambda_kl * kl_loss
                if self.use_vgg:
                    GAN_G = GAN_G_loss + self.lambda_fm * fm_loss + self.lambda_kl * kl_loss + self.lambda_vgg * vgg_loss
                GAN_G.backward()
                self.G_optim.step()

            if step % self.n_save == 0:
                self.save_ckpt(self.ckpt_dir, step, epoch, self.save_Dis)

            if step % self.n_save_image == 0:
                fake_image_ref, fake_image_latent = self.eval(
                    test_seg, test_real_image)
                self.save_img(self.img_dir, test_seg, test_real_image,
                              fake_image_ref, fake_image_latent, step)
                print()

            # NOTE(review): if n_summary is not a multiple of n_critic, the
            # G-side values logged here are stale (from the last G step) --
            # confirm this is intended.
            if step % self.n_summary == 0:
                self.writeLogs(summary, step, GAN_D_real_loss, GAN_D_fake_loss,
                               GAN_D, GAN_G, GAN_G_loss,
                               self.lambda_fm * fm_loss,
                               self.lambda_kl * kl_loss,
                               self.lambda_vgg * vgg_loss)

            state_msg = (
                'Epo : {} ; '.format(epoch) +
                'D_real : {:0.3f} ; D_fake : {:0.3f} ; '.format(
                    GAN_D_real_loss, GAN_D_fake_loss) +
                'total_D : {:0.3f} ; total_G : {:0.3f} ; '.format(
                    GAN_D, GAN_G) + 'G : {:0.3f} ; FM : {:0.3f} ; '.format(
                        GAN_G_loss, self.lambda_fm * fm_loss) +
                'kl : {:0.3f} ; vgg : {:0.3f} ;'.format(
                    self.lambda_kl * kl_loss, self.lambda_vgg * vgg_loss))

            pbar.set_description(state_msg)

    def writeLogs(self, summary, step, D_real, D_fake, D, G, GAN_G, fm, kl,
                  vgg):
        """Write one row of scalar losses to the summary writer.

        BUG FIX: `vgg` arrives as a plain int 0 when use_vgg is False
        (see train()), so calling `.item()` unconditionally crashed.
        Convert tensors and plain numbers uniformly.
        """
        def _scalar(value):
            # Tensors need .item(); disabled losses are already numbers.
            return value.item() if torch.is_tensor(value) else float(value)

        summary.add_scalar('D_real', _scalar(D_real), step)
        summary.add_scalar('D_fake', _scalar(D_fake), step)
        summary.add_scalar('D', _scalar(D), step)
        summary.add_scalar('G', _scalar(G), step)
        summary.add_scalar('GAN_G', _scalar(GAN_G), step)
        summary.add_scalar('fm', _scalar(fm), step)
        summary.add_scalar('kl', _scalar(kl), step)
        summary.add_scalar('vgg', _scalar(vgg), step)

    def save_ckpt(self, dir, step, epoch, save_Dis):
        """Save encoder/generator (and optionally discriminator) weights.

        :param dir: checkpoint directory.
        :param step: current step; file is named <step+1 zero-padded>.ckpt.
        :param epoch: current epoch (unused here, kept for interface).
        :param save_Dis: also persist discriminator weights when True.
        """
        model_dict = {}
        model_dict['enc'] = self.enc.state_dict()
        model_dict['gen'] = self.gen.state_dict()
        if save_Dis:
            model_dict['disORI'] = self.disORI.state_dict()
            model_dict['disHAL'] = self.disHAL.state_dict()
            # model_dict['disQUA'] = self.disQUA.state_dict()
        torch.save(model_dict, os.path.join(dir,
                                            f'{str(step+1).zfill(7)}.ckpt'))

    def save_img(self, dir, seg, real_img, fake_img_ref, fake_img_latent,
                 step):
        """Dump a 4-column grid (seg | real | fake-from-ref | fake-from-latent).

        The first seg row is seeded before the loop, hence the i == 0 branch
        skips re-appending it.
        """
        image_grid = torch.unsqueeze(seg[0], dim=0)
        for i in range(self.grid_n_row):
            if i == 0:
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(real_img[i], dim=0)), dim=0)
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(fake_img_ref[i], dim=0)),
                    dim=0)
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(fake_img_latent[i], dim=0)),
                    dim=0)
            else:
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(seg[i], dim=0)), dim=0)
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(real_img[i], dim=0)), dim=0)
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(fake_img_ref[i], dim=0)),
                    dim=0)
                image_grid = torch.cat(
                    (image_grid, torch.unsqueeze(fake_img_latent[i], dim=0)),
                    dim=0)
            # Never index past the test batch.
            if i == self.test_batch_size - 1:
                break

        image_grid = make_grid(image_grid, nrow=4, padding=2)
        save_image(image_grid, os.path.join(dir,
                                            f'{str(step+1).zfill(7)}.jpg'))

    def load(self, dir, ckpt_step, save_Dis):
        """Load weights written by save_ckpt for the given step."""
        model_dict = torch.load(
            os.path.join(dir, f'{str(ckpt_step).zfill(7)}.ckpt'))
        self.enc.load_state_dict(model_dict['enc'])
        self.gen.load_state_dict(model_dict['gen'])
        if save_Dis:
            self.disORI.load_state_dict(model_dict['disORI'])
            self.disHAL.load_state_dict(model_dict['disHAL'])
            # self.disQUA.load_state_dict(model_dict['disQUA'])

    def eval(self, seg, real_image):
        """Generate two fakes per seg: from the encoded reference image and
        from a random latent.

        NOTE(review): latent dim 256 is hard-coded here -- confirm it matches
        the encoder's z dimension.
        """
        _, _, z = self.enc(real_image)
        fake_image_ref = self.gen(z, seg)

        latent = torch.randn(self.test_batch_size, 256).to(self.device)
        fake_image_latent = self.gen(latent, seg)

        return fake_image_ref, fake_image_latent

    def test(self):
        pass
示例#12
0
def main():

    torch.manual_seed(114514)
    torch.cuda.manual_seed_all(114514)

    model = to_cuda(NetBasic(top_softmax))
    model = DataParallel(model, device_ids=[0, 1, 2, 3])
    model.apply(weights_init)
    criterion = nn.CrossEntropyLoss()  # ce_loss
    optimizer = Adam(model.parameters(), lr=args.lr)

    print "executing fold %d" % args.fold
    # import dataset

    train_dataset = ExclusionDataset(luna_dir,
                                     data_index_dir,
                                     fold=args.fold,
                                     phase='train')
    X_train, y_train = load_data(train_dataset, nodule_dir)
    unlabeled_dataset = ExclusionDataset(luna_dir,
                                         data_index_dir,
                                         fold=args.fold,
                                         phase='unlabeled')
    X_ul = load_data(unlabeled_dataset, nodule_dir)
    print "Labeled training samples: %d" % len(train_dataset)
    if args.semi_spv == 0:
        print "supervised mission"
    else:
        print "semi-supervised mission"
        print "Unlabeled training samples: %d" % len(unlabeled_dataset)
    # data argumentation
    if args.argument != 0:
        X_train, y_train = argumentation(X_train, y_train, args.argument)

    # parameters for training
    batch_size = args.batch_size
    print
    print
    ce_loss_list = []
    vat_loss_list = []
    for epoch in range(args.epochs):
        print "epoch: %d" % (epoch + 1)

        # epoch decay settings
        if epoch <= args.epochs * 0.5:
            decayed_lr = args.lr
        elif epoch <= args.epochs * 0.8:
            decayed_lr = 0.1 * args.lr
        else:
            decayed_lr = ((args.epochs - epoch) *
                          (0.1 * args.lr)) / (args.epochs -
                                              (0.8 * args.epochs))
        optimizer.lr = decayed_lr
        optimizer.betas = (0.5, 0.999)
        print "contains %d iterations." % num_iter_per_epoch
        for i in tqdm(range(num_iter_per_epoch)):
            # training in batches
            batch_indices = torch.LongTensor(
                np.random.choice(len(train_dataset), batch_size,
                                 replace=False))
            x_64 = X_train[batch_indices]
            y = y_train[batch_indices]
            x_32 = extract_half(x_64)

            # semi-supervised, we used same batch-size for both labeled and unlabeled
            if args.semi_spv == 1:
                batch_indices_unlabeled = torch.LongTensor(
                    np.random.choice(len(unlabeled_dataset),
                                     batch_size,
                                     replace=False))
                ul_x_64 = X_ul[batch_indices_unlabeled]
                ul_x_32 = extract_half(ul_x_64)
                v_loss, ce_loss = train_semi(model.train(),
                                             Variable(to_cuda(x_32)),
                                             Variable(to_cuda(x_64)),
                                             Variable(to_cuda(y)),
                                             Variable(to_cuda(ul_x_32)),
                                             Variable(to_cuda(ul_x_64)),
                                             optimizer,
                                             criterion,
                                             epsilon=args.epsilon,
                                             lamb=args.lamb)
                if i == num_iter_per_epoch - 1:
                    print "epoch %d: " % (
                        epoch + 1), "vat_loss: ", v_loss, "ce_loss: ", ce_loss
                    ce_loss_list.append(ce_loss)
                    vat_loss_list.append(v_loss)

            # supervised with cross-entropy loss
            else:
                sv_loss = train_supervise(model.train(),
                                          Variable(to_cuda(x_32)),
                                          Variable(to_cuda(x_64)),
                                          Variable(to_cuda(y)), optimizer,
                                          criterion)
                if i == num_iter_per_epoch - 1:
                    print "epoch %d: " % (epoch + 1), "sv_loss", sv_loss
                    ce_loss_list.append(sv_loss)

    # saving model
    print "saving model..."
    state_dict = model.module.state_dict()
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].cpu()
    if args.semi_spv == 1:
        save_dir = os.path.join(args.save_dir, 'fold%d' % args.fold,
                                'semi_spv')
    else:
        save_dir = os.path.join(args.save_dir, 'fold%d' % args.fold,
                                'supervise')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save(
        {
            'save_dir': args.save_dir,
            'state_dict': state_dict,
            'args': args
        }, os.path.join(save_dir, 'model.ckpt'))

    # Saving loss results
    print "Saving loss results"
    if args.semi_spv == 1:
        ce_loss_list = np.asarray(ce_loss_list, dtype=np.float64)
        vat_loss_list = np.asarray(vat_loss_list, dtype=np.float64)
        np.save(os.path.join(save_dir, 'vat_loss.npy'), vat_loss_list)
        np.save(os.path.join(save_dir, 'ce_loss.npy'), ce_loss_list)
    else:
        ce_loss_list = np.asarray(ce_loss_list, dtype=np.float64)
        np.save(os.path.join(save_dir, 'sv_loss.npy'), ce_loss_list)

    # Generating test results one by one
    print "Evaluation step..."
    test_dataset = ExclusionDataset(luna_dir,
                                    data_index_dir,
                                    fold=args.fold,
                                    phase='test')
    print "Testing samples: %d" % len(test_dataset)
    X_test, y_test, uids, center = load_data(test_dataset, nodule_dir)
    y_test = y_test.numpy()
    series_uid_list = []
    coord_x_list = []
    coord_y_list = []
    coord_z_list = []
    proba_pos_list = []
    proba_neg_list = []
    label_list = []
    print "Testing..."
    for i in tqdm(range(len(test_dataset))):
        prob_neg, prob_pos = evaluate(
            model.eval(), Variable(to_cuda(extract_half(X_test[[i]]))),
            Variable(to_cuda(X_test[[i]])))
        series_uid_list.append(uids[i])
        coord_x_list.append(center[i][0])
        coord_y_list.append(center[i][1])
        coord_z_list.append(center[i][2])
        proba_neg_list.append(prob_neg)
        proba_pos_list.append(prob_pos)
        label_list.append(y_test[i])
    print "Finished evaluation step, generating evaluation files.."
    # Saving results
    data_frame = DataFrame({
        'seriesuid': series_uid_list,
        'coordX': coord_x_list,
        'coordY': coord_y_list,
        'coordZ': coord_z_list,
        'proba_neg': proba_neg_list,
        'proba_pos': proba_pos_list,
        'label': label_list
    })
    data_frame.to_csv(os.path.join(save_dir, 'eval_results.csv'),
                      index=False,
                      sep=',')
示例#13
0
class Solver(object):
    def __init__(self, opt):
        """Create output directories, build models/optimizers/data, restore state."""
        self.opt = opt
        self.name = opt.name
        self.output_dir = Path(opt.output_dir) / self.name

        # Create each output sub-directory and remember it as an attribute.
        for attr_name, subdir in (('preddump_dir', 'preddump'),
                                  ('sample_dir', 'sample'),
                                  ('log_dir', 'tensorboard'),
                                  ('ckpt_dir', 'ckpt')):
            directory = self.output_dir / subdir
            directory.mkdir(parents=True, exist_ok=True)
            setattr(self, attr_name, directory)

        self.global_iter = 0
        self.init_loss_functions()
        self.init_colorize()
        self.init_models_optimizers_data()
        self.load_states()

        # Only training needs a tensorboard writer.
        if not opt.inference:
            self.writer = SummaryWriter(self.log_dir,
                                        purge_step=self.global_iter)

    def init_models_optimizers_data(self):
        """Construct the networks, their Adam optimizers and the data pipeline.

        Three configurations are wired up, driven by ``opt``:
        - plain PyTorch training (DataParallel wrappers, pytorch loaders),
        - deepspeed training (``ds.initialize`` per model, deepspeed train
          loader, pytorch test loader),
        - inference (pytorch loaders, test batch size 1, no shuffling).

        Side effects: sets encoder/decoder/frame_predictor/posterior/prior
        (and their optimizers) plus training/testing batch generators on self.
        """
        opt = self.opt
        device = opt.device
        # ConvLSTM input is the frame embedding concatenated with the latent.
        self.encoder, self.decoder = get_autoencoder(opt)
        self.frame_predictor = DeterministicConvLSTM(opt.g_dim + opt.z_dim,
                                                     opt.g_dim, opt.rnn_size,
                                                     opt.predictor_rnn_layers,
                                                     opt.batch_size, opt.M)
        self.posterior = GaussianConvLSTM(opt.g_dim, opt.z_dim, opt.rnn_size,
                                          opt.posterior_rnn_layers,
                                          opt.batch_size, opt.M)
        self.prior = GaussianConvLSTM(opt.g_dim, opt.z_dim, opt.rnn_size,
                                      opt.prior_rnn_layers, opt.batch_size,
                                      opt.M)
        # Under deepspeed, device placement is handled by ds.initialize below.
        if not opt.deepspeed:
            self.encoder = self.encoder.to(device)
            self.decoder = self.decoder.to(device)
            self.frame_predictor = self.frame_predictor.to(device)
            self.posterior = self.posterior.to(device)
            self.prior = self.prior.to(device)

        # One Adam per sub-network; deepspeed replaces these on initialize.
        self.frame_predictor_optimizer = optim.Adam(
            self.frame_predictor.parameters(),
            lr=opt.lr,
            betas=(opt.beta1, 0.999))
        self.posterior_optimizer = optim.Adam(self.posterior.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
        self.prior_optimizer = optim.Adam(self.prior.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, 0.999))
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.frame_predictor.apply(init_weights)
        self.posterior.apply(init_weights)
        self.prior.apply(init_weights)
        self.encoder.apply(init_weights)
        self.decoder.apply(init_weights)

        # Trainable-parameter iterators handed to ds.initialize below.
        encoder_params = filter(lambda p: p.requires_grad,
                                self.encoder.parameters())
        decoder_params = filter(lambda p: p.requires_grad,
                                self.decoder.parameters())
        frame_predictor_params = filter(lambda p: p.requires_grad,
                                        self.frame_predictor.parameters())
        posterior_params = filter(lambda p: p.requires_grad,
                                  self.posterior.parameters())
        prior_params = filter(lambda p: p.requires_grad,
                              self.prior.parameters())

        # Optionally warm-start from a DataParallel or deepspeed checkpoint.
        if opt.load_dp_ckpt:
            self.load_dp_ckpt()
        if opt.load_ds_ckpt:
            self.load_ds_ckpt()

        train_data, test_data = load_dataset(opt)
        if opt.inference:
            # use pytorch loaders for both train/test loader in inference mode
            train_loader = DataLoader(train_data,
                                      num_workers=opt.data_threads,
                                      batch_size=opt.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      pin_memory=True)
            test_loader = DataLoader(test_data,
                                     num_workers=opt.data_threads,
                                     batch_size=1,
                                     shuffle=False,
                                     drop_last=False,
                                     pin_memory=True)
        elif not opt.inference and not opt.deepspeed:
            # use pytorch loaders for both train/test loader when not using deepspeed
            train_loader = DataLoader(train_data,
                                      num_workers=opt.data_threads,
                                      batch_size=opt.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      pin_memory=True)
            test_loader = DataLoader(test_data,
                                     num_workers=opt.data_threads,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)
        elif not opt.inference and opt.deepspeed:
            # use deepspeed train loader when training with deepspeed.
            # use pytorch test loader when testing
            test_loader = DataLoader(test_data,
                                     num_workers=opt.data_threads,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)

        if opt.deepspeed:
            # Only the first ds.initialize call sets up the distributed
            # backend (dist_init_required=True); it also supplies the
            # deepspeed train loader when training.
            if not opt.inference:
                self.encoder, self.encoder_optimizer, train_loader, _ = ds.initialize(
                    opt,
                    model=self.encoder,
                    model_parameters=encoder_params,
                    dist_init_required=True,
                    training_data=train_data)
            else:
                self.encoder, self.encoder_optimizer, _, _ = ds.initialize(
                    opt,
                    model=self.encoder,
                    model_parameters=encoder_params,
                    dist_init_required=True)
            self.decoder, self.decoder_optimizer, _, _ = ds.initialize(
                opt,
                model=self.decoder,
                model_parameters=decoder_params,
                dist_init_required=False)
            self.frame_predictor, self.frame_predictor_optimizer, _, _ = ds.initialize(
                opt,
                model=self.frame_predictor,
                model_parameters=frame_predictor_params,
                dist_init_required=False)
            self.posterior, self.posterior_optimizer, _, _ = ds.initialize(
                opt,
                model=self.posterior,
                model_parameters=posterior_params,
                dist_init_required=False)
            self.prior, self.prior_optimizer, _, _ = ds.initialize(
                opt,
                model=self.prior,
                model_parameters=prior_params,
                dist_init_required=False)

            def normalize_data_ds(opt, sequence):
                # Move time to dim 0 and place data on this rank's device.
                data, data_path = sequence
                data.transpose_(0, 1)
                return data.to(self.encoder.local_rank), data_path

            def get_batch(loader):
                # Infinite generator: restarts the loader when exhausted.
                while True:
                    for sequence in loader:
                        batch = normalize_data_ds(opt, sequence)
                        yield batch

            def get_dump_batch(loader):
                # Single-pass generator, used for inference dumps.
                for sequence in loader:
                    batch = normalize_data_ds(opt, sequence)
                    yield batch
        else:
            self.encoder = DataParallel(self.encoder)
            self.decoder = DataParallel(self.decoder)
            self.frame_predictor = DataParallel(self.frame_predictor)
            self.posterior = DataParallel(self.posterior)
            self.prior = DataParallel(self.prior)

            if opt.device == 'cuda':
                dtype = torch.cuda.FloatTensor
            else:
                dtype = torch.FloatTensor

            def get_batch(loader):
                # Infinite generator: restarts the loader when exhausted.
                while True:
                    for sequence in loader:
                        batch = normalize_data_dp(opt, dtype, sequence)
                        yield batch

            def get_dump_batch(loader):
                # Single-pass generator, used for inference dumps.
                for sequence in loader:
                    batch = normalize_data_dp(opt, dtype, sequence)
                    yield batch

        self.training_batch_generator = get_batch(train_loader)
        if opt.inference:
            self.testing_batch_generator = get_dump_batch(test_loader)
        else:
            self.testing_batch_generator = get_batch(test_loader)

    def init_colorize(self):
        """Configure n_class, the PIL palette list and the Colorize helper
        for the dataset selected in opt; raises ValueError for unknown sets."""
        dataset = self.opt.dataset
        if dataset in ('KITTI_64', 'KITTI_128', 'KITTI_256',
                       'Cityscapes_128x256'):
            num_classes = 19
            make_cmap = lambda: return_colormap('KITTI')
        elif dataset in ('Pose_64', 'Pose_128'):
            num_classes = 25
            make_cmap = lambda: return_colormap(N=25)
        else:
            raise ValueError()

        self.opt.n_class = num_classes
        # Flattened byte palette in PIL putpalette format.
        self.pallette = make_cmap().byte().numpy().reshape(-1).tolist()
        self.colorize = Colorize(num_classes, make_cmap())

    def load_states(self, idx=None):
        """Restore model (and iteration counter) state from a checkpoint.

        :param idx: checkpoint index; ``None`` selects the 'last' checkpoint.
        The deepspeed and plain-pytorch paths are mutually exclusive,
        chosen from ``self.opt``; a missing/broken checkpoint is reported
        and training starts fresh.
        """
        if self.opt.deepspeed and not self.opt.load_dp_ckpt:
            if idx is None:
                idx = 'last'
            savedir = self.ckpt_dir / str(idx)
            # `savedir` is always a Path here (the original `is not None`
            # guard was vacuous); rely on load_checkpoint raising when the
            # checkpoint is absent.
            try:
                _, _ = self.encoder.load_checkpoint(savedir, 'encoder')
                _, _ = self.decoder.load_checkpoint(savedir, 'decoder')
                _, _ = self.frame_predictor.load_checkpoint(
                    savedir, 'frame_predictor')
                _, _ = self.posterior.load_checkpoint(savedir, 'posterior')
                _, client_state = self.prior.load_checkpoint(savedir, 'prior')
                # All sub-models share the same client state; read the step
                # from the last one loaded.
                self.global_iter = client_state['step']
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit still propagate.
                printstr = 'ckpt is not found at: %s' % savedir
                print_rank_0(printstr)
                return
            else:
                printstr = 'ckpt is loaded from: %s' % savedir
                print_rank_0(printstr)

        if not self.opt.deepspeed and not self.opt.load_ds_ckpt:
            idx = 'last.pth' if idx is None else '%d.pth' % idx
            path = self.ckpt_dir / idx
            try:
                ckpt = torch.load(path)

                self.global_iter = ckpt['global_iter']

                self.frame_predictor.load_state_dict(ckpt['frame_predictor'])
                self.posterior.load_state_dict(ckpt['posterior'])
                self.prior.load_state_dict(ckpt['prior'])
                self.encoder.load_state_dict(ckpt['encoder'])
                self.decoder.load_state_dict(ckpt['decoder'])
            except Exception:
                # Narrowed from a bare `except:`; a missing file or a
                # mismatched state dict both land here.
                printstr = 'failed to load ckpt from: %s' % path
                print(printstr)
            else:
                printstr = 'ckpt is loaded from: %s' % path
                print(printstr)

    def dump_states(self, idx=None):
        """Persist training state under ``ckpt_dir``.

        With DeepSpeed, each engine writes its own checkpoint directory and
        carries the step counter in ``client_state``. Otherwise, every
        module/optimizer state dict is bundled into a single ``<idx>.pth``.

        :param idx: Checkpoint index/tag; ``None`` means 'last' for the
            DeepSpeed layout.
        """
        module_names = ('encoder', 'decoder', 'frame_predictor', 'posterior',
                        'prior')
        if self.opt.deepspeed:
            tag = 'last' if idx is None else str(idx)
            savedir = self.ckpt_dir / tag
            client_state = {'step': self.global_iter, 'opt': self.opt}
            for name in module_names:
                getattr(self, name).save_checkpoint(savedir, name,
                                                    client_state)
        else:
            state = {'global_iter': self.global_iter}
            for name in module_names:
                state[name] = getattr(self, name).state_dict()
                optimizer = getattr(self, name + '_optimizer')
                state[name + '_optimizer'] = optimizer.state_dict()
            state['opt'] = self.opt
            torch.save(state, '%s/%s.pth' % (self.ckpt_dir, idx))

    def load_dp_ckpt(self, idx=None):
        """Load a ``.pth`` checkpoint whose state dicts were presumably saved
        from ``DataParallel``-wrapped modules (keys prefixed ``module.``):
        each network is temporarily wrapped so the keys line up, then the
        bare ``.module`` is kept.

        :param idx: Checkpoint index; ``None`` selects ``last.pth``.
        """
        filename = 'last.pth' if idx is None else '%d.pth' % idx
        path = self.ckpt_dir / filename
        try:
            ckpt = torch.load(path)
        except FileNotFoundError as e:
            print(e)
            return

        self.global_iter = ckpt['global_iter']

        for name in ('encoder', 'decoder', 'frame_predictor', 'posterior',
                     'prior'):
            wrapped = DataParallel(getattr(self, name))
            wrapped.load_state_dict(ckpt[name])
            setattr(self, name, wrapped.module)

        print('ckpt is loaded from: %s' % path)

    def load_ds_ckpt(self, idx=None):
        """Load DeepSpeed-format checkpoints (one directory per sub-network)
        into the plain modules and optimizers, restoring the step counter
        from the encoder's checkpoint.

        :param idx: Checkpoint tag; ``None`` selects 'last'.
        """
        tag = 'last' if idx is None else str(idx)
        template = str(self.ckpt_dir / tag / '%s/mp_rank_00_model_states.pt')
        names = ('encoder', 'decoder', 'frame_predictor', 'posterior',
                 'prior')

        try:
            ckpts = {name: torch.load(template % name) for name in names}
        except FileNotFoundError as e:
            print(e)
            return

        for name in names:
            getattr(self, name).load_state_dict(ckpts[name]['module'])
            getattr(self, name + '_optimizer').load_state_dict(
                ckpts[name]['optimizer'])
        self.global_iter = ckpts['encoder']['step']
        # The '%s' placeholder is deliberately left unexpanded: the message
        # shows the per-network path pattern rather than one concrete file.
        print('ckpt is loaded from: %s' % template)

    def init_loss_functions(self):
        """Set up the training criteria.

        ``kl_criterion`` (module-level helper) measures the divergence between
        the posterior and prior distributions; ``NLLLoss`` scores the decoded
        per-pixel class predictions against integer targets (see ``train``).
        """
        self.kl_criterion = kl_criterion
        self.nll = nn.NLLLoss()

    def train(self, x):
        """Run one optimization step over a frame sequence.

        Teacher-forces the model through ``n_past + n_future - 1``
        transitions, accumulating the pixel-wise NLL of each predicted frame
        and the KL between posterior and prior, then backpropagates the
        combined loss and steps every optimizer.

        :param x: Sequence of frame batches indexed by time step.
        :return: dict with 'nll' and 'kld', each averaged over time steps.
        """
        for net in (self.encoder, self.decoder, self.frame_predictor,
                    self.posterior, self.prior):
            net.zero_grad()

        nll = 0
        kld = 0
        prior_hidden = None
        posterior_hidden = None
        frame_predictor_hidden = None
        for t in range(1, self.opt.n_past + self.opt.n_future):
            x_in, x_target = x[t - 1], x[t]

            h = self.encoder(x_in)
            h_target = self.encoder(x_target)[0]

            # Keep the skip connection from ground-truth frames (always, if
            # last_frame_skip; otherwise only during the conditioning phase).
            if self.opt.last_frame_skip or t < self.opt.n_past + 1:
                h, skip = h
            else:
                h = h[0]

            z_t, mu, logvar, posterior_hidden = self.posterior(
                h_target, posterior_hidden)
            _, mu_p, logvar_p, prior_hidden = self.prior(h, prior_hidden)
            h_pred, frame_predictor_hidden = self.frame_predictor(
                torch.cat([h, z_t], 1), frame_predictor_hidden)

            x_pred = self.decoder([h_pred, skip])
            nll = nll + self.nll(x_pred, x_target.squeeze(1).long())
            kld = kld + self.kl_criterion(mu, logvar, mu_p, logvar_p)

        (nll + kld * self.opt.beta).backward()

        for optimizer in (self.encoder_optimizer, self.decoder_optimizer,
                          self.frame_predictor_optimizer,
                          self.posterior_optimizer, self.prior_optimizer):
            optimizer.step()

        normalizer = self.opt.n_past + self.opt.n_future
        return {
            'nll': nll.item() / normalizer,
            'kld': kld.item() / normalizer,
        }

    @torch.no_grad()
    def validate(self, x):
        """Evaluate one sequence without gradient tracking.

        Mirrors ``train``'s forward pass (teacher-forced throughout) but
        performs no backward pass or optimizer updates.

        :param x: Sequence of frame batches indexed by time step.
        :return: dict with 'nll' and 'kld', each averaged over time steps.
        """
        nll = 0
        kld = 0
        prior_hidden = None
        posterior_hidden = None
        frame_predictor_hidden = None
        for t in range(1, self.opt.n_past + self.opt.n_future):
            x_in, x_target = x[t - 1], x[t]

            h = self.encoder(x_in)
            h_target = self.encoder(x_target)[0]

            # Same skip-connection policy as in train().
            if self.opt.last_frame_skip or t < self.opt.n_past + 1:
                h, skip = h
            else:
                h = h[0]

            z_t, mu, logvar, posterior_hidden = self.posterior(
                h_target, posterior_hidden)
            _, mu_p, logvar_p, prior_hidden = self.prior(h, prior_hidden)
            h_pred, frame_predictor_hidden = self.frame_predictor(
                torch.cat([h, z_t], 1), frame_predictor_hidden)

            x_pred = self.decoder([h_pred, skip])
            nll = nll + self.nll(x_pred, x_target.squeeze(1).long())
            kld = kld + self.kl_criterion(mu, logvar, mu_p, logvar_p)

        normalizer = self.opt.n_past + self.opt.n_future
        return {
            'nll': nll.item() / normalizer,
            'kld': kld.item() / normalizer,
        }

    def solve(self):
        """Main training loop: run from ``global_iter`` to ``opt.max_iter``,
        taking one optimization step per iteration and periodically
        checkpointing, logging scalars/images, and validating.
        """
        pbar = tqdm(range(self.global_iter, self.opt.max_iter))
        start_time = time.time()
        for _ in pbar:
            self.global_iter += 1

            # Put every sub-network back in training mode (the image-logging
            # branch below switches them to eval()).
            self.frame_predictor.train()
            self.posterior.train()
            self.prior.train()
            self.encoder.train()
            self.decoder.train()

            x, _ = next(self.training_batch_generator)

            # train: one optimization step; returns per-timestep-averaged
            # losses
            output = self.train(x)
            nll = output['nll']
            kld = output['kld']

            if self.global_iter % self.opt.log_ckpt_iter == 0:
                # save the model: an indexed snapshot plus the rolling 'last'
                self.dump_states(self.global_iter)
                self.dump_states('last')

            if time.time() - start_time > self.opt.log_ckpt_sec:
                # save the model on a wall-clock schedule as well
                self.dump_states('last')
                start_time = time.time()

            if self.global_iter % self.opt.print_iter == 0:
                printstr = '[%02d] nll: %.5f | kld loss: %.5f' % (
                    self.global_iter,
                    nll,
                    kld,
                )
                #tprint_rank_0(pbar, printstr)
                pbar.set_description(printstr)

            if self.global_iter % self.opt.log_line_iter == 0:
                self.writer.add_scalar('train_nll',
                                       nll,
                                       global_step=self.global_iter)
                self.writer.add_scalar('train_kld',
                                       kld,
                                       global_step=self.global_iter)

            if self.global_iter % self.opt.log_img_iter == 0:
                # plot some stuff on a test batch; eval mode for inference
                self.frame_predictor.eval()
                self.posterior.eval()
                self.prior.eval()
                self.encoder.eval()
                self.decoder.eval()

                x, _ = next(self.testing_batch_generator)
                # Under distributed training, only rank 0 writes images.
                if torch.distributed.is_initialized():
                    if torch.distributed.get_rank() == 0:
                        plot(x, self)
                else:
                    plot(x, self)

            if self.global_iter % self.opt.validate_iter == 0:
                nll = 0
                kld = 0
                nvalsample = 0
                for _ in range(100):
                    x, _ = next(self.testing_batch_generator)
                    output = self.validate(x)
                    nll += output['nll']
                    kld += output['kld']
                    nvalsample += x[0].size(0)

                # NOTE(review): validate() already averages over time steps,
                # and the sum here is over 100 batch-averages, yet it is
                # divided by the total sample count — confirm this mixed
                # normalization is intended.
                nll /= nvalsample
                kld /= nvalsample
                self.writer.add_scalar('test_nll',
                                       nll,
                                       global_step=self.global_iter)
                self.writer.add_scalar('test_kld',
                                       kld,
                                       global_step=self.global_iter)
        pbar.close()

    @torch.no_grad()
    def inference(self):
        """Generate ``opt.n_prediction`` sampled rollouts per test sequence
        and dump ground-truth and predicted masks as paletted PNGs under
        ``preddump_dir``.
        """
        topil = transforms.ToPILImage()

        n_prediction = self.opt.n_prediction

        # Inference only: put all sub-networks in eval mode.
        self.frame_predictor.eval()
        self.posterior.eval()
        self.prior.eval()
        self.encoder.eval()
        self.decoder.eval()
        for batch_idx, (x_seqs, paths) in tqdm(
                enumerate(self.testing_batch_generator)):

            # When the unrolling horizon exceeds the available ground-truth
            # frames, pad by repeating the last frame and synthesize the
            # next file name using each dataset's frame-naming scheme.
            for _ in range(self.opt.n_past + self.opt.n_eval - len(x_seqs)):
                x_seqs.append(x_seqs[-1])
                path_parts = paths[-1][0].split('/')
                name = path_parts[-1]
                if 'KITTI' in self.opt.dataset:
                    newname = '%s_%010d.png' % (
                        '_'.join(name.strip('.png').split('_')[:-1]),
                        int(name.strip('.png').split('_')[-1]) +
                        self.opt.frame_sampling_rate)
                elif 'Cityscapes' in self.opt.dataset:
                    new_idx = '%06d' % (int(name.split('_')[-2]) +
                                        self.opt.frame_sampling_rate)
                    parts = name.split('_')
                    parts[-2] = new_idx
                    newname = '_'.join(parts)
                elif 'Pose' in self.opt.dataset:
                    # NOTE(review): duplicated "newname = newname =" looks
                    # like a typo (harmless). Also, with no else branch, an
                    # unrecognized dataset would leave `newname` unbound.
                    newname = newname = 'frame%06d_IUV.png' % (
                        int(name.strip('.png').strip('frame').split('_')[0]) +
                        self.opt.frame_sampling_rate)
                newpath = ['/'.join(path_parts[:-1] + [newname])]
                paths.append(newpath)

            # Draw n_prediction independent rollouts for this batch.
            x_pred_seqs = []
            for s in range(n_prediction):
                skip = None
                prior_hidden = None
                posterior_hidden = None
                frame_predictor_hidden = None
                x_in = x_seqs[0]
                x_pred_seq = [x_in.data.cpu().byte()]
                for i in range(1, self.opt.n_past + self.opt.n_eval):

                    h = self.encoder(x_in)
                    if self.opt.last_frame_skip or i < self.opt.n_past + 1:
                        h, skip = h
                    else:
                        h = h[0]

                    if i < self.opt.n_past:
                        # Conditioning phase: encode the ground-truth target
                        # and warm up the recurrent hidden states; the next
                        # input is the real frame.
                        x_target = x_seqs[i]
                        h_target = self.encoder(x_target)[0]
                        z_t, _, _, posterior_hidden = self.posterior(
                            h_target, posterior_hidden)
                        _, _, _, prior_hidden = self.prior(h, prior_hidden)
                        _, frame_predictor_hidden = self.frame_predictor(
                            torch.cat([h, z_t], 1), frame_predictor_hidden)
                        x_in = x_target
                    else:
                        # Free-running phase: sample z from the prior and
                        # feed the argmax of the decoded prediction back in.
                        z_t, _, _, prior_hidden = self.prior(h, prior_hidden)
                        h_pred, frame_predictor_hidden = self.frame_predictor(
                            torch.cat([h, z_t], 1), frame_predictor_hidden)
                        x_in = self.decoder([h_pred,
                                             skip]).argmax(dim=1, keepdim=True)

                    x_pred_seq.append(x_in.data.cpu().byte())

                x_pred_seq = torch.stack(x_pred_seq, dim=1)
                x_pred_seqs.append(x_pred_seq)

            x_seqs = torch.cat(
                x_seqs).data.cpu().byte()  # (n_past+n_eval, 1, H, W)
            x_pred_seqs = torch.cat(x_pred_seqs).transpose(
                0, 1)  # (n_past+n_eval, n_prediction, 1, H, W)

            # Write out the ground truth (sample_00000) and every sampled
            # rollout (sample_00001..) under dataset-specific directories.
            for x_gt, x_preds, path in zip(x_seqs, x_pred_seqs, paths):
                path = Path(path[0])
                if 'KITTI' in self.opt.dataset:
                    maskpath = self.preddump_dir.joinpath(
                        str(self.global_iter), 'batch_%05d' % (batch_idx + 1),
                        'sample_%05d' % (0), Path(path.parts[3], path.name))
                elif 'Cityscapes' in self.opt.dataset:
                    maskpath = self.preddump_dir.joinpath(
                        str(self.global_iter), 'batch_%05d' % (batch_idx + 1),
                        'sample_%05d' % (0), Path(path.parts[3], path.name))
                elif 'Pose' in self.opt.dataset:
                    vidname, clipname = path.parts[-3:-1]
                    maskpath = self.preddump_dir.joinpath(
                        str(self.global_iter), 'batch_%05d' % (batch_idx + 1),
                        'sample_%05d' % (0), vidname + '_' + clipname,
                        path.name)

                maskpath.parent.mkdir(exist_ok=True, parents=True)
                # Convert the class-index image to a paletted PNG.
                x_gt = topil(x_gt).convert('P', colors=self.opt.n_class)
                x_gt.putpalette(self.pallette)
                x_gt.save(maskpath)

                for num_x_pred, x_pred in enumerate(x_preds):
                    if 'KITTI' in self.opt.dataset:
                        maskpath = self.preddump_dir.joinpath(
                            str(self.global_iter),
                            'batch_%05d' % (batch_idx + 1),
                            'sample_%05d' % (num_x_pred + 1),
                            Path(path.parts[3], path.name))
                    elif 'Cityscapes' in self.opt.dataset:
                        maskpath = self.preddump_dir.joinpath(
                            str(self.global_iter),
                            'batch_%05d' % (batch_idx + 1),
                            'sample_%05d' % (num_x_pred + 1),
                            Path(path.parts[3], path.name))
                    elif 'Pose' in self.opt.dataset:
                        maskpath = self.preddump_dir.joinpath(
                            str(self.global_iter),
                            'batch_%05d' % (batch_idx + 1),
                            'sample_%05d' % (num_x_pred + 1),
                            vidname + '_' + clipname, path.name)
                    maskpath.parent.mkdir(exist_ok=True, parents=True)
                    x_pred = topil(x_pred).convert('P',
                                                   colors=self.opt.n_class)
                    x_pred.putpalette(self.pallette)
                    x_pred.save(maskpath)