Example #1
    def _save_training_details(self, reward, size, reward_net_key=None):
        if not os.path.exists(self.folder):
            os.makedirs(self.folder)

        with open(os.path.join(self.folder, "training.json"), "wt") as file:
            reward_type = "env" if reward is None else "net"
            with io.StringIO() as out, redirect_stdout(out):
                summary(self, torch.zeros(get_input_shape()).to(self.current_device()), show_input=True)
                summary(self, torch.zeros(get_input_shape()).to(self.current_device()), show_input=False)
                net_summary = out.getvalue()
            print(net_summary)
            # self.name = os.path.basename(os.path.normpath(self.folder))
            j = {"name": self.name, "type": str(type(self)), "str": str(self).replace("\n", ""), "reward_type": reward_type,
                 "size": size, "max_episodes": self.max_episodes, "optimizer": str(self.optimizer),
                 "summary": net_summary}

            if hasattr(self, "scheduler"):
                j["scheduler"] = str(self.scheduler.__class__.__name__)
                if hasattr(self, "scheduler_kwargs"):
                    j["scheduler_kwargs"] = self.scheduler_kwargs

            if reward_type == "net":
                j["reward_net_key"] = reward_net_key
                j["reward_net_details"] = str(reward)

            json.dump(j, file, indent=True)
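The example above captures the table that summary() prints to stdout and embeds it in a JSON file. A minimal, self-contained sketch of that pattern, assuming the modelsummary-style summary used across these examples (which prints its table); the tiny Net module and the output path here are hypothetical stand-ins for the project-specific helpers (get_input_shape, current_device):

import io
import json
from contextlib import redirect_stdout

import torch
import torch.nn as nn
from modelsummary import summary


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
with io.StringIO() as out, redirect_stdout(out):
    # summary() prints its table to stdout (the behaviour the example above
    # relies on); redirect_stdout diverts it into the StringIO buffer
    summary(net, torch.zeros(1, 8), show_input=True)
    summary(net, torch.zeros(1, 8), show_input=False)
    net_summary = out.getvalue()

with open("training.json", "wt") as file:
    json.dump({"summary": net_summary}, file, indent=True)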
Example #2
    def save_training_details(self, batch_size, num_subtrajectories, subtrajectory_length, use_also_complete_trajectories, train_games):
        print("saving details")
        if not os.path.exists(self.folder):
            os.makedirs(self.folder)
            print('created dir "' + self.folder + '"')

        with open(os.path.join(self.folder, "training.json"), "wt") as file:
            with io.StringIO() as out, redirect_stdout(out):
                summary(self, torch.zeros(get_input_shape()).to(self.current_device()), show_input=True)
                summary(self, torch.zeros(get_input_shape()).to(self.current_device()), show_input=False)
                net_summary = out.getvalue()
            print(net_summary)
            j = {"type": str(type(self)), "str": str(self).replace("\n", ""), "optimizer": str(self.optimizer),
                 "penalty_rewards": self.lambda_abs_rewards, "batch_size": batch_size,
                 "num_subtrajectories": num_subtrajectories, "subtrajectory_length": subtrajectory_length,
                 "use_also_complete_trajectories": use_also_complete_trajectories, "summary": net_summary,
                 "max_epochs": self.max_epochs}
            if train_games is not None:
                j["games"] = train_games
            json.dump(j, file, indent=True)
            print('details saved on ' + os.path.join(self.folder, "training.json"))
Example #3
def train_model(data_set_identifier, train_file, val_file, learning_rate, minibatch_size, name):
    set_experiment_id(data_set_identifier, learning_rate, minibatch_size, name)

    train_loader = contruct_dataloader_from_disk(train_file, minibatch_size, use_evolutionary=True)
    validation_loader = contruct_dataloader_from_disk(val_file, minibatch_size, use_evolutionary=True)
    validation_dataset_size = len(validation_loader.dataset)
    train_dataset_size = len(train_loader.dataset)

    # 21 channels for the one-hot encoded amino acids, doubled to 42 when the
    # PSSM (evolutionary) features are concatenated later in the training loop
    embedding_size = 21
    if configs.run_params["use_evolutionary"]:
        embedding_size = 42


    # Load an existing model if given as an argument,
    # otherwise construct a new model from the config file
    if args.model is not None:
        model_path = "output/models/" + args.model + ".model"
        model = load_model_from_disk(model_path, use_gpu)
    else:
        model = construct_model(configs.model_params, embedding_size, use_gpu, minibatch_size)
    
    #optimizer parameters
    betas = tuple(configs.run_params["betas"])
    weight_decay = configs.run_params["weight_decay"]
    angle_lr = configs.run_params["angles_lr"]

    if configs.model_params['architecture'] == 'cnn_angles':
        optimizer = optim.Adam(model.parameters(), betas=betas, lr=learning_rate, weight_decay=weight_decay)
    else:
        optimizer = optim.Adam([
            {'params' : model.model.parameters(), 'lr':learning_rate},
            {'params' : model.soft_to_angle.parameters(), 'lr':angle_lr}], betas=betas, weight_decay=weight_decay)
    
    #print number of trainable parameters
    print_number_of_parameters(model)
    #For creating a summary table of the model (does not work on ExampleModel!)
    if configs.run_params["print_model_summary"]:
        if configs.model_params["architecture"] != 'rnn':
            summary(model, configs.run_params["max_sequence_length"], 2)
        else:
            write_out("DETAILED MODEL SUMMARY IS NOT SUPPORTED FOR RNN MODELS")
    
    if use_gpu:
        model = model.cuda()

    # TODO: is soft_to_angle.parameters() included here?

    sample_num = list()
    train_loss_values = list()
    validation_loss_values = list()
    rmsd_avg_values = list()
    drmsd_avg_values = list()
    break_point_values = list()

    breakpoints = configs.run_params['breakpoints']
    best_model_loss = 1e20
    best_model_train_loss = 1e20
    best_model_minibatch_time = None
    best_model_path = None
    stopping_condition_met = False
    minibatches_proccesed = 0

    loss_atoms = configs.run_params["loss_atoms"]
    start_time = time.time()
    max_time = configs.run_params["max_time"]
    C_epochs = configs.run_params["c_epochs"] # TODO: Change to parameter
    C_batch_updates = C_epochs

    while not stopping_condition_met:
        optimizer.zero_grad()
        model.zero_grad()
        loss_tracker = np.zeros(0)
        start_time_n_minibatches = time.time()
        for minibatch_id, training_minibatch in enumerate(train_loader, 0):
            minibatches_proccesed += 1
            training_minibatch = list(training_minibatch)
            primary_sequence, tertiary_positions, mask, p_id = training_minibatch[:-1]
            # Update C
            C = 1.0 if minibatches_proccesed >= C_batch_updates else float(minibatches_proccesed) / C_batch_updates

            # One-hot encode the amino acid string and concatenate PSSM values.
            amino_acids, batch_sizes = one_hot_encode(primary_sequence, 21, use_gpu)

            if configs.run_params["use_evolutionary"]:
                evolutionary = training_minibatch[-1]

                evolutionary, batch_sizes = torch.nn.utils.rnn.pad_packed_sequence(torch.nn.utils.rnn.pack_sequence(evolutionary))
                
                if use_gpu:
                    evolutionary = evolutionary.cuda()

                amino_acids = torch.cat((amino_acids, evolutionary.view(-1, len(batch_sizes) , 21)), 2)

            start_compute_loss = time.time()

            if configs.run_params["only_angular_loss"]:
                #raise NotImplementedError("Only_angular_loss function has not been implemented correctly yet.")
                loss = model.compute_angular_loss((amino_acids, batch_sizes), tertiary_positions, mask)
            else:
                loss = model.compute_loss((amino_acids, batch_sizes), tertiary_positions, mask, C=C, loss_atoms=loss_atoms)
            
            if C != 1:
                write_out("C:", C)
            write_out("Train loss:", float(loss))
            start_compute_grad = time.time()
            loss.backward()
            loss_tracker = np.append(loss_tracker, float(loss))
            end = time.time()
            write_out("Loss time:", start_compute_grad-start_compute_loss, "Grad time:", end-start_compute_grad)
            optimizer.step()
            optimizer.zero_grad()
            model.zero_grad()

            # for every eval_interval samples, plot performance on the validation set
            if minibatches_proccesed % configs.run_params["eval_interval"] == 0:
                model.eval()
                write_out("Testing model on validation set...")
                train_loss = loss_tracker.mean()
                loss_tracker = np.zeros(0)
                validation_loss, data_total, rmsd_avg, drmsd_avg = evaluate_model(validation_loader,
                     model, use_gpu, loss_atoms, configs.run_params["use_evolutionary"])
                prim = data_total[0][0]
                pos = data_total[0][1]
                pos_pred = data_total[0][3]
                mask = data_total[0][4]
                pos = apply_mask(pos, mask)
                angles_pred = data_total[0][2]

                angles_pred = apply_mask(angles_pred, mask, size=3)

                pos_pred = apply_mask(pos_pred, mask)
                prim = torch.masked_select(prim, mask)

                if use_gpu:
                    pos = pos.cuda()
                    pos_pred = pos_pred.cuda()

                angles = calculate_dihedral_angels(pos, use_gpu)
                #angles_pred = calculate_dihedral_angels(pos_pred, use_gpu)
                #angles_pred = data_total[0][2] # Use angles output from model - calculate_dihedral_angels(pos_pred, use_gpu)

                write_to_pdb(get_structure_from_angles(prim, angles), "test")
                write_to_pdb(get_structure_from_angles(prim, angles_pred), "test_pred")
                if validation_loss < best_model_loss:
                    best_model_loss = validation_loss
                    best_model_minibatch_time = minibatches_proccesed
                    best_model_path = write_model_to_disk(model)

                if train_loss < best_model_train_loss:
                    best_model_train_loss = train_loss
                    best_model_train_path = write_model_to_disk(model, model_type="train")

                write_out("Validation loss:", validation_loss, "Train loss:", train_loss)
                write_out("Best model so far (validation loss): ", best_model_loss, "at time", best_model_minibatch_time)
                write_out("Best model stored at " + best_model_path)
                write_out("Best model train stored at " + best_model_train_path)
                write_out("Minibatches processed:",minibatches_proccesed)

                end_time_n_minibatches = time.time()
                n_minibatches_time_used = end_time_n_minibatches - start_time_n_minibatches
                minibatches_left = configs.run_params["max_updates"] - minibatches_proccesed
                seconds_left = int(n_minibatches_time_used * (minibatches_left/configs.run_params["eval_interval"]))
                
                m, s = divmod(seconds_left, 60)
                h, m = divmod(m, 60)
                write_out("Estimated time until maximum number of updates:", '{:d}:{:02d}:{:02d}'.format(h, m, s) )
                sample_num.append(minibatches_proccesed)
                train_loss_values.append(train_loss)
                validation_loss_values.append(validation_loss)
                rmsd_avg_values.append(rmsd_avg)
                drmsd_avg_values.append(drmsd_avg)
                
                if breakpoints and minibatches_proccesed > breakpoints[0]:
                    break_point_values.append(drmsd_avg)
                    breakpoints = breakpoints[1:]

                data = {}
                data["pdb_data_pred"] = open("output/protein_test_pred.pdb","r").read()
                data["pdb_data_true"] = open("output/protein_test.pdb","r").read()
                data["validation_dataset_size"] = validation_dataset_size
                data["sample_num"] = sample_num
                data["train_loss_values"] = train_loss_values
                data["break_point_values"] = break_point_values
                data["validation_loss_values"] = validation_loss_values
                data["phi_actual"] = list([math.degrees(float(v)) for v in angles[1:,1]])
                data["psi_actual"] = list([math.degrees(float(v)) for v in angles[:-1,2]])
                data["phi_predicted"] = list([math.degrees(float(v)) for v in angles_pred[1:,1]])
                data["psi_predicted"] = list([math.degrees(float(v)) for v in angles_pred[:-1,2]])
                data["drmsd_avg"] = drmsd_avg_values
                data["rmsd_avg"] = rmsd_avg_values
                if not configs.run_params["hide_ui"]:
                    res = requests.post('http://localhost:5000/graph', json=data)
                    if res.ok:
                        print(res.json())
                
                # Save run data
                write_run_to_disk(data)

                #Check if maximum time is reached.
                start_time_n_minibatches = time.time()
                time_used = time.time() - start_time

                time_condition = (max_time is not None and time_used > max_time)
                max_update_condition = minibatches_proccesed >= configs.run_params["max_updates"]
                min_update_condition = (minibatches_proccesed > configs.run_params["min_updates"] and minibatches_proccesed > best_model_minibatch_time * 2)

                model.train()
                #Checking for stop conditions
                if time_condition or max_update_condition or min_update_condition:
                    stopping_condition_met = True
                    break
    write_out("Best validation model found after" , best_model_minibatch_time , "minibatches.")
    write_result_summary(best_model_loss)
    return best_model_path
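The C used in the loss above is a linear curriculum/warm-up coefficient. A sketch of that schedule in isolation (the function name is ours, not the project's):

# C grows from ~0 to 1.0 over the first c_batch_updates minibatches,
# then stays at 1.0, exactly as the inline update in the loop above.
def curriculum_coefficient(minibatches_processed, c_batch_updates):
    if minibatches_processed >= c_batch_updates:
        return 1.0
    return float(minibatches_processed) / c_batch_updates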
Example #4
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2),
        )

    def forward(self, x): # x : [1, 3, 227, 227]
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

model = AlexNet()
print(model)
summary(model, torch.zeros((1, 3, 227, 227)))
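The 256 * 6 * 6 flatten size in forward can be sanity-checked with the usual conv/pool output-size formula. The sketch below assumes the standard AlexNet stem (an 11x11, stride-4 convolution) for the first layer, which is truncated off the top of this snippet:

# out = (in + 2*pad - kernel) // stride + 1
def conv_out(size, kernel, stride=1, pad=0):
    return (size + 2 * pad - kernel) // stride + 1

s = 227
s = conv_out(s, 11, 4)     # assumed AlexNet stem conv (truncated above): 55
s = conv_out(s, 3, 2)      # max pool: 27
s = conv_out(s, 5, 1, 2)   # conv 96 -> 256: 27
s = conv_out(s, 3, 2)      # max pool: 13
s = conv_out(s, 3, 1, 1)   # conv 256 -> 384: 13
s = conv_out(s, 3, 1, 1)   # conv 384 -> 384: 13
s = conv_out(s, 3, 1, 1)   # conv 384 -> 256: 13
s = conv_out(s, 3, 2)      # max pool: 6
print(s)                   # 6, matching the 256 * 6 * 6 view in forward()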
Example #5
                                                       str(index) + '.bias'])

        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, fake, real):
        calc_index = [0, 3, 6, 8, 11, 13, 16, 18]
        if fake.size()[1] == 1:
            fake = fake.repeat(1, 3, 1, 1)
        if real.size()[1] == 1:
            real = real.repeat(1, 3, 1, 1)

        content_loss = 0
        for idx, sub_module in enumerate(self.vgg):
            fake = sub_module(fake)
            real = sub_module(real)
            if idx in calc_index:
                content_loss += ((fake - real)**2).mean()
        return content_loss / len(calc_index)

    # def forward(self, inp):
    #     return self.vgg(inp)


if __name__ == "__main__":
    from modelsummary import summary
    model = VGG_Encoder()
    input = torch.randn(4, 3, 256, 256).to('cuda')
    # The active forward takes (fake, real), so the tensor is passed twice here
    summary(model, input, input)
    output = model(input, input)
    print(output.size())
Example #6
    def __init__(self, decomposer, shader):
        super(Composer, self).__init__()

        self.decomposer = decomposer
        self.shader = shader

    def forward(self, inp):
        reflectance, shape, lights = self.decomposer(inp)
        shading = self.shader(shape, lights)
        shading_rep = shading.repeat(1, 3, 1, 1)
        reconstruction = reflectance * shading_rep
        return reconstruction, reflectance, shading, shape


if __name__ == "__main__":
    from modelsummary import summary
    shader = Shader(use_rrcnn_block=True, use_attention=True).to('cuda')
    lights = torch.randn(4, 4).to('cuda')
    shape = torch.randn(4, 3, 256, 256).to('cuda')
    datas = [shape, lights]
    summary(shader, *datas)
    x = shader(shape, lights)
    print(x.size())
    # decomposer = Decomposer(use_rcnn_block=True, use_attention=True).to('cuda')
    # inp = torch.randn(4, 3, 256, 256).to('cuda')
    # summary(decomposer, inp, show_hierarchical=True)
    # y = decomposer.forward(inp)
    # print(y[0].size())
    # print(y[1].size())
    # print(y[2].size())
Example #7
        self.ConvBlock_f = ConvBlock_b(8, 3)

    def forward(self, x):
        if self.prev_t is None:
            self.prev_t = torch.zeros_like(x)  # same shape, dtype and device as the input
        x = torch.cat((x, self.prev_t), dim=1)
        e1 = self.ConvBlock_b1(x)
        e2 = self.Maxpool(e1)
        e2 = self.ConvBlock_b2(e2)
        e3 = self.Maxpool(e2)
        e3 = self.ConvBlock_b3(e3)
        e4 = self.Maxpool(e3)
        e4 = e4.unsqueeze(0)
        d4, _ = self.SRU(e4)
        d3 = self.UpBlock_b3(d4[0].squeeze(1), e3)
        d2 = self.UpBlock_b2(d3, e2)
        d1 = self.UpBlock_b1(d2, e1)
        x = self.ConvBlock_f(d1)
        self.prev_t = x.detach()  # store the state without keeping the autograd graph alive

        return x


if __name__ == "__main__":
    device = torch.device("cuda:0")
    model = R_UNet().cuda()
    summary(model, torch.zeros((1, 3, 512, 512)).cuda(), show_input=True)
    summary(model, torch.zeros((1, 3, 512, 512)).cuda(), show_input=False)

    y = model(torch.zeros((1, 3, 512, 512)).cuda())
Example #8
class LeNet5(nn.Module):
    # Reconstructed class head: the original snippet is truncated here, so the
    # signature is inferred from the usage below (LeNet5(10, True) with a
    # 1-channel 32x32 input, read as num_classes and a grayscale flag).
    def __init__(self, num_classes, grayscale=False):
        super(LeNet5, self).__init__()
        in_channels = 1 if grayscale else 3
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6 * in_channels, kernel_size=5), nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6 * in_channels, 16 * in_channels, kernel_size=5),
            nn.Tanh(), nn.MaxPool2d(kernel_size=2))

        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5 * in_channels, 120 * in_channels),
            nn.Tanh(),
            nn.Linear(120 * in_channels, 84 * in_channels),
            nn.Tanh(),
            nn.Linear(84 * in_channels, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas


if __name__ == '__main__':

    # Device
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    model = LeNet5(10, True)
    model.to(device)
    summary(model,
            torch.ones(128, 1, 32, 32).to(device),
            batch_size=128,
            show_input=False)
Example #9
        x = self.Maxpool(x)
        x = self.RRCNN2(x)

        x = self.Maxpool(x)
        x = self.RRCNN3(x)

        x = self.Maxpool(x)
        x = self.RRCNN4(x)

        x = self.Maxpool(x)
        x = self.RRCNN5(x)

        x = self.Maxpool(x)
        x = self.RRCNN6(x)

        x = self.Maxpool(x)
        x = self.RRCNN7(x)
        x = x.view(-1, 128 * 6 * 6)
        x = self.classifier(x)

        return x


if __name__ == "__main__":
    device = torch.device("cpu")
    model = Cls_R2U_Net()

    summary(model, torch.zeros((1, 3, 400, 400)), show_input=True)
    summary(model, torch.zeros((1, 3, 400, 400)), show_input=False)
Example #10
    def train(self, batch_size, learning_rate, begin_epoch, epochs):

        train_data, test_data = self.loading_data_for_train(test_size=0.1)
        train_loader = Data.DataLoader(dataset=train_data,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=self.NUM_WORKERS)
        test_loader = Data.DataLoader(dataset=test_data,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=self.NUM_WORKERS)

        model = self.load_model(model_file=None)
        summary(model,
                torch.zeros((1, 1, self.patch_size, self.patch_size)),
                show_input=False)

        criterion = nn.BCELoss(reduction='mean')
        if self.use_GPU:
            model.to(self.device)
            criterion.to(self.device)

        # Optimizer for the network
        # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
        # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=0.9,
        )
        # optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, alpha=0.99, )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min',
            factor=0.5)  # with mode='min', the LR is multiplied by factor when the loss plateaus; 'max' is the reverse

        # training and testing
        for epoch in range(begin_epoch, begin_epoch + epochs):
            print('Epoch {}/{}'.format(epoch + 1, begin_epoch + epochs))
            print('-' * 80)

            model.train()
            # Start training
            train_data_len = len(train_loader)
            total_loss = 0
            starttime = datetime.datetime.now()
            for step, (x, y) in enumerate(
                    train_loader
            ):  # draw batch data; x is normalized while iterating train_loader
                b_x = Variable(x.to(self.device))
                b_y = Variable(y.to(self.device))

                output = model(b_x)  # cnn output
                loss = criterion(output, b_y)

                optimizer.zero_grad()  # clear gradients for this training step
                loss.backward()  # backpropagation, compute gradients
                optimizer.step()

                # Statistics
                running_loss = loss.item()
                total_loss += running_loss

                if step % 50 == 0:
                    endtime = datetime.datetime.now()
                    remaining_time = (train_data_len - step) * (
                        endtime - starttime).seconds / (step + 1)
                    print(
                        '%d / %d ==> Loss: %.4f \t total loss %.4f: ,  remaining time: %d (s)'
                        % (step, train_data_len, running_loss, total_loss,
                           remaining_time))

            scheduler.step(total_loss)

            running_loss = 0.0
            running_corrects = 0
            model.eval()
            # Start evaluation
            for x, y in test_loader:
                b_x = Variable(x.to(self.device))
                b_y = Variable(y.to(self.device))

                output = model(b_x)
                loss = criterion(output, b_y)

                running_loss += loss.item() * b_x.size(0)

            test_data_len = len(test_data)
            epoch_loss = running_loss / test_data_len

            save_filename = self.model_root + "/{}_cp-{:04d}-{:.4f}-{:.4f}.pth".format(
                self.model_name, epoch + 1, epoch_loss,
                total_loss / train_data_len)
            torch.save(model.state_dict(), save_filename)
            print("Saved ", save_filename)
Example #11
                      round(256 * self.coef_width), bilinear)
        self.up4 = Up(round(128 * self.coef_width),
                      round(64 * self.coef_width),
                      round(128 * self.coef_width), bilinear)
        self.outc = OutConv(round(64 * self.coef_width), n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # x5 = input_checker(x5, round(1024*self.coef_width))
        x = self.up1(x5, x4)
        # x = input_checker(x, round(512*self.coef_width))
        x = self.up2(x, x3)
        # x = input_checker(x, round(256*self.coef_width))
        x = self.up3(x, x2)
        # x = input_checker(x, round(128*self.coef_width))
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


if __name__ == "__main__":
    model = UNet()
    param = unet_params('unet-w0')

    summary(model, torch.zeros((1, 3, param[-2], param[-2])), show_input=True)
    summary(model, torch.zeros((1, 3, param[-2], param[-2])), show_input=False)
Example #12
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelsummary import summary


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# show input shape
summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True)

# show output shape
summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False)
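The same call also takes the batch_size and show_hierarchical keywords seen elsewhere on this page (Examples #6 and #8); a short variant, assuming this modelsummary build accepts both together:

# hierarchical view, with shapes reported for an explicit batch size
summary(Net(), torch.zeros((1, 1, 28, 28)), batch_size=16,
        show_input=True, show_hierarchical=True)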
Example #13
    def train(self, data_filename, class_weight, batch_size, loss_weight, epochs):
        '''
        Train the filter
        :param data_filename: sample file name
        :param class_weight: class weights
        :param batch_size: batch size
        :param loss_weight: weight of the center loss
        :param epochs: epochs
        :return:
        '''
        filename = "{}/data/{}".format(self._params.PROJECT_ROOT, data_filename)
        D = np.load(filename, allow_pickle=True)
        X = D['x']
        Y = D['y']
        # X = np.reshape(X, (-1,1,self.image_size, self.image_size))

        X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.05, )

        # train_data = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),
        #                                             torch.from_numpy(y_train).long())
        # test_data = torch.utils.data.TensorDataset(torch.from_numpy(X_test).float(),
        #                                            torch.from_numpy(y_test).long())
        train_data = Sparse_Image_Dataset(X_train, y_train)
        test_data = Sparse_Image_Dataset(X_test, y_test)
        train_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True,
                                       num_workers=self.NUM_WORKERS)
        test_loader = Data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False,
                                      num_workers=self.NUM_WORKERS)

        model = self.load_model(model_file=None)
        summary(model, torch.zeros((1, 1, self.image_size, self.image_size)), show_input=False)

        if class_weight is not None:
            class_weight = torch.FloatTensor(class_weight)

        classifi_loss = nn.CrossEntropyLoss(weight=class_weight)
        # center_loss = CenterLoss(self.num_classes, 2)
        center_loss = LGMLoss_v0(self.num_classes, 2, 1.00)
        if self.use_GPU:
            model.to(self.device)
            classifi_loss.to(self.device)
            center_loss.to(self.device)

        # Optimizer for the network
        # classifi_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
        # classifi_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # classifi_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
        classifi_optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, alpha=0.99)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)  # every step_size epochs, the LR is multiplied by gamma
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(classifi_optimizer, mode='min',
                                                               factor=0.5)  # with mode='min', the LR is multiplied by factor when the loss plateaus; 'max' is the reverse
        # Optimizer for the center loss
        optimzer4center = torch.optim.SGD(center_loss.parameters(), lr=0.1)

        # training and testing
        for epoch in range(epochs):
            print('Epoch {}/{}'.format(epoch + 1, epochs))
            print('-' * 80)

            model.train()
            # Start training
            train_data_len = len(train_loader)
            total_loss = 0
            starttime = datetime.datetime.now()
            for step, (x, y) in enumerate(train_loader):  # draw batch data; x is normalized while iterating train_loader
                b_x = Variable(x.to(self.device))
                b_y = Variable(y.to(self.device))

                output = model(b_x)  # cnn output

                # cross entropy loss + center loss
                # loss = classifi_loss(output, b_y) + loss_weight * center_loss(b_y, output)

                logits, mlogits, likelihood = center_loss(output, b_y)
                loss = classifi_loss(mlogits, b_y) + loss_weight * likelihood

                classifi_optimizer.zero_grad()  # clear gradients for this training step
                optimzer4center.zero_grad()
                loss.backward()  # backpropagation, compute gradients
                classifi_optimizer.step()
                optimzer4center.step()

                # Statistics
                _, preds = torch.max(output, 1)

                running_loss = loss.item()
                running_corrects = torch.sum(preds == b_y.data)
                total_loss += running_loss

                if step % 50 == 0:
                    endtime = datetime.datetime.now()
                    remaining_time = (train_data_len - step) * (endtime - starttime).seconds / (step + 1)
                    print('%d / %d ==> Loss: %.4f | Acc: %.4f ,  remaining time: %d (s)'
                          % (step, train_data_len, running_loss, running_corrects.double()/b_x.size(0), remaining_time))

            scheduler.step(total_loss)

            running_loss = 0.0
            running_corrects = 0
            model.eval()
            # Start evaluation
            for x, y in test_loader:
                b_x = Variable(x.to(self.device))
                b_y = Variable(y.to(self.device))

                output = model(b_x)
                loss = classifi_loss(output, b_y)

                _, preds = torch.max(output, 1)
                running_loss += loss.item() * b_x.size(0)
                running_corrects += torch.sum(preds == b_y.data)

            test_data_len = len(test_data)
            epoch_loss = running_loss / test_data_len
            epoch_acc = running_corrects.double() / test_data_len

            save_filename = self.model_root + "/{}_{}_cp-{:04d}-{:.4f}-{:.4f}.pth".format(
                self.model_name, self.patch_type, epoch + 1, epoch_loss, epoch_acc)
            torch.save(model.state_dict(), save_filename)
            print("Saved ", save_filename)
        return
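The training step above drives two optimizers from one backward pass: one for the network and one for the learnable parameters of the auxiliary (LGM/center) loss. A stripped-down sketch of that pattern, with a hypothetical stand-in for LGMLoss_v0:

import torch
import torch.nn as nn

class CenterLossStub(nn.Module):
    # Hypothetical stand-in for LGMLoss_v0: an auxiliary loss with its own
    # learnable parameters (per-class centers), updated by a second optimizer.
    def __init__(self, num_classes, feat_dim):
        super(CenterLossStub, self).__init__()
        self.centers = nn.Parameter(torch.zeros(num_classes, feat_dim))

    def forward(self, feats, labels):
        return ((feats - self.centers[labels]) ** 2).mean()

model = nn.Linear(8, 2)
center_loss = CenterLossStub(num_classes=10, feat_dim=2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, alpha=0.99)
optimzer4center = torch.optim.SGD(center_loss.parameters(), lr=0.1)

x, y = torch.randn(4, 8), torch.randint(0, 10, (4,))
loss = center_loss(model(x), y)

optimizer.zero_grad()        # clear gradients on both parameter sets
optimzer4center.zero_grad()
loss.backward()              # one backward pass populates both
optimizer.step()             # update the network...
optimzer4center.step()       # ...and the loss's centers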