Exemplo n.º 1
0
def train_and_valid(learning_rate=lr,
                    weight_decay=weight_decay,
                    num_of_res=num_of_res_blocks,
                    if_bottleneck=bottle,
                    plot=True):
    """
    Train the model and run it on the valid set every epoch.

    After the last epoch, the parameters that reached the lowest average
    validation loss are written to ``./models/ckpoint_*.tar`` together with
    the epoch index and that loss.

    :param learning_rate: initial learning rate handed to Adam
    :param weight_decay: for L2 regularization (Adam ``weight_decay``)
    :param num_of_res: forwarded to ResNet as ``numOfResBlock``
    :param if_bottleneck: forwarded to ResBlock as ``bottleneck``
    :param plot: draw the train/valid loss curve or not
    :return: None
    """
    # Local import: only needed to snapshot the best weights below.
    import copy

    curr_lr = learning_rate
    # model define
    if NET_TYPE == 'res':
        channels = 128 if DATA_TYPE == 'hoa' else 256
        block = ResBlock(channels, channels, bottleneck=if_bottleneck)
        model = ResNet(block,
                       numOfResBlock=num_of_res,
                       input_shape=input_shape,
                       data_type=DATA_TYPE).to(device)
    elif NET_TYPE == 'hoa':
        model = HOANet(input_shape=input_shape).to(device)
    else:
        raise RuntimeError('Unrecognized net type!')

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)

    # These parameters are for searching the best epoch to early stopping
    train_loss_curve, valid_loss_curve = [], []
    best_loss, avr_valid_loss = 10000.0, 0.0

    best_epoch = 0
    best_model = None  # the best parameters

    # Root directory shared by the train ('tr/') and valid ('cv/') splits.
    data_root = DATA_PATH + ('' if DATA_TYPE == 'hoa' else 'STFT/')

    for epoch in range(num_epochs):
        # Accumulated train / valid loss for this epoch
        train_loss_per_epoch, valid_loss_per_epoch = 0.0, 0.0
        train_step_cnt, valid_step_cnt = 0, 0

        train_data, valid_data = [], []
        # Switch to training mode
        model.train()
        random.shuffle(train_file_order)

        # NOTE(review): whatever is still buffered in ``train_data`` when this
        # loop ends is never trained on (same for ``valid_data`` below), so the
        # last file(s) of every pass are dropped -- confirm this is intended.
        for train_idx in train_file_order:
            if len(train_data) < batch_size:
                # Keep loading files until at least one full batch is buffered.
                train_data_temp = HOADataSet(path=data_root + 'tr/',
                                             index=train_idx + 1,
                                             data_type=DATA_TYPE,
                                             is_speech=SPEECH)
                if len(train_data) == 0:
                    train_data = train_data_temp
                else:
                    train_data += train_data_temp
                continue

            train_loader = data.DataLoader(dataset=train_data,
                                           batch_size=batch_size,
                                           shuffle=True)

            for examples, labels in train_loader:
                train_step_cnt += 1
                examples = examples.float().to(device)
                labels = labels.float().to(device)
                outputs = model(examples)
                train_loss = criterion(outputs, labels)
                train_loss_per_epoch += train_loss.item()

                # Backward and optimize
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                logger.info(
                    "Epoch [{}/{}], Step {}, train Loss: {:.4f}".format(
                        epoch + 1, num_epochs, train_step_cnt,
                        train_loss.item()))

            # Start the next buffer with the file that triggered this round.
            train_data = HOADataSet(path=data_root + 'tr/',
                                    index=train_idx + 1,
                                    data_type=DATA_TYPE,
                                    is_speech=SPEECH)

        if plot:
            train_loss_curve.append(train_loss_per_epoch / train_step_cnt)

        if running_lr and epoch > 1 and (epoch + 1) % 2 == 0:
            curr_lr = curr_lr * (1 - decay)
            update_lr(optimizer, curr_lr)

        # valid every epoch
        # Switch to evaluation mode
        model.eval()
        with torch.no_grad():
            for valid_idx in valid_file_order:
                if len(valid_data) < batch_size:
                    valid_data_temp = HOADataSet(path=data_root + 'cv/',
                                                 index=valid_idx + 1,
                                                 data_type=DATA_TYPE,
                                                 is_speech=SPEECH)
                    if len(valid_data) == 0:
                        valid_data = valid_data_temp
                    else:
                        valid_data += valid_data_temp
                    continue

                valid_loader = data.DataLoader(dataset=valid_data,
                                               batch_size=batch_size,
                                               shuffle=True)

                for examples, labels in valid_loader:
                    valid_step_cnt += 1
                    examples = examples.float().to(device)
                    labels = labels.float().to(device)

                    outputs = model(examples)
                    valid_loss = criterion(outputs, labels)
                    valid_loss_per_epoch += valid_loss.item()

                    logger.info(
                        'The loss for the current batch:{}'.format(valid_loss))

                valid_data = HOADataSet(path=data_root + 'cv/',
                                        index=valid_idx + 1,
                                        data_type=DATA_TYPE,
                                        is_speech=SPEECH)

            avr_valid_loss = valid_loss_per_epoch / valid_step_cnt

            logger.info(
                'Epoch {}, the average loss on the valid set: {} '.format(
                    epoch, avr_valid_loss))

            valid_loss_curve.append(avr_valid_loss)
            if avr_valid_loss < best_loss:
                best_loss = avr_valid_loss
                best_epoch = epoch
                # state_dict() returns references to the live parameter
                # tensors; without a deep copy later optimizer steps would
                # overwrite this "best" snapshot in place.
                best_model = copy.deepcopy(model.state_dict())

    # end for loop of epoch
    torch.save(
        {
            'epoch': best_epoch,
            'state_dict': best_model,
            'loss': best_loss,
        }, './models/ckpoint_' + CUR_TASK + '_bot' + str(int(if_bottleneck)) +
        '_lr' + str(learning_rate) + '_wd' + str(weight_decay) + '_#res' +
        str(num_of_res) + '.tar')

    logger.info('best epoch:{}, valid loss:{}'.format(best_epoch, best_loss))
    if plot:
        x = np.arange(num_epochs)
        fig, ax = plt.subplots(1, 1)
        ax.plot(x, train_loss_curve, 'b', label='Train Loss')
        ax.plot(x, valid_loss_curve, 'r', label='Valid Loss')
        plt.legend(loc='upper right')
        plt.savefig(name + '.jpg')
        plt.close()
Exemplo n.º 2
0
class Test:
    """Load a trained network and visualize the three vectors it predicts for an image."""

    def __init__(self, model_name, snapshot, num_classes):
        """
        Build the requested network, load its weights and move it to GPU 0.

        :param model_name: "resnet50" or "mobilenetv2"
        :param snapshot: path of the saved state dict, read with ``torch.load``
        :param num_classes: number of classes, forwarded to the model and to
            ``utils.classify2vector``
        :raises ValueError: if ``model_name`` is not one of the supported names
        """
        self.num_classes = num_classes
        if model_name == "resnet50":
            self.model = ResNet(torchvision.models.resnet50(pretrained=False),
                                num_classes)
        elif model_name == "mobilenetv2":
            self.model = MobileNetV2(num_classes=num_classes)
        else:
            # Previously this printed a message and called exit(0), i.e. the
            # process terminated reporting success; fail loudly instead.
            raise ValueError("No such model: {!r}".format(model_name))

        self.saved_state_dict = torch.load(snapshot)
        self.model.load_state_dict(self.saved_state_dict)
        self.model.cuda(0)
        self.model.eval()  # Change model to 'eval' mode

        self.softmax = nn.Softmax(dim=1).cuda(0)

        self.optimizer = Optimize()

    def draw_vectors(self, pred_vector1, pred_vector2, pred_vector3, img,
                     center, width):
        """
        Draw the three predicted vectors on ``img`` and show it in a window.

        The raw predictions are first passed through
        ``self.optimizer.Get_Ortho_Vectors``; only the x and y components of
        each resulting vector are handed to ``utils.draw_front``.

        :param pred_vector1: first predicted vector (x, y, z)
        :param pred_vector2: second predicted vector (x, y, z)
        :param pred_vector3: third predicted vector (x, y, z)
        :param img: image to draw on
        :param center: point whose first two entries are used as tdx / tdy
        :param width: forwarded unchanged to ``utils.draw_front``
        """
        optimize_v = self.optimizer.Get_Ortho_Vectors(np.array(pred_vector1),
                                                      np.array(pred_vector2),
                                                      np.array(pred_vector3))

        # One colour per vector, in order: blue, green, red.
        colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
        vectors = (optimize_v[0], optimize_v[1], optimize_v[2])
        for vector, color in zip(vectors, colors):
            predx, predy, _predz = vector
            utils.draw_front(img,
                             predx,
                             predy,
                             width,
                             tdx=center[0],
                             tdy=center[1],
                             size=100,
                             color=color)

        cv.imshow("pose visualization", img)

    def test_per_img(self, cv_img, draw_img, center, w):
        """
        Run the model on one input and draw the resulting vectors.

        :param cv_img: model input; moved to GPU 0 with ``.cuda(0)``
        :param draw_img: image the vectors are drawn on
        :param center: forwarded to ``draw_vectors``
        :param w: forwarded to ``draw_vectors`` as ``width``
        """
        with torch.no_grad():
            images = cv_img.cuda(0)

            # get x,y,z cls predictions
            x_v1, y_v1, z_v1, x_v2, y_v2, z_v2, x_v3, y_v3, z_v3 = self.model(
                images)

            # get prediction vector (continuous value from the classify result)
            _, _, _, pred_vector1 = utils.classify2vector(
                x_v1, y_v1, z_v1, self.softmax, self.num_classes)
            _, _, _, pred_vector2 = utils.classify2vector(
                x_v2, y_v2, z_v2, self.softmax, self.num_classes)
            _, _, _, pred_vector3 = utils.classify2vector(
                x_v3, y_v3, z_v3, self.softmax, self.num_classes)

            # visualize vectors (only the first item of the batch is drawn)
            self.draw_vectors(pred_vector1[0].cpu().tolist(),
                              pred_vector2[0].cpu().tolist(),
                              pred_vector3[0].cpu().tolist(), draw_img, center,
                              w)
Exemplo n.º 3
0
    # 读取checkpoint保存的模型,在验证集上跑一遍,计算准确率和召回率
    path = './models/ckpoint_' + name + '.tar'
    checkpoint = torch.load(path)
    # input_shape = (c.hoa_num * 2, c.frames_per_block+2, c.n_freq)
    if DATA_TYPE == 'hoa':
        block = ResBlock(128, 128, bottleneck=bottle)
    else:
        block = ResBlock(256, 256, bottleneck=bottle)
    model = ResNet(block,
                   numOfResBlock=num_of_res_blocks,
                   input_shape=input_shape,
                   data_type=DATA_TYPE).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    criterion = nn.MSELoss()
    model.eval()
    f = open(name + '.txt', 'w')
    file_ = open('anechoic_mono_speech_tt.flist', 'r')
    all_test_files = file_.readlines()
    with torch.no_grad():
        for snr in snr_list:
            offset = index_offset_dict[snr]
            total_correct = np.zeros(numOfEth)  # 每个误差容忍度都对应一个准确个数
            total_recall = np.zeros(numOfEth)
            total_peaks, total_predict = 0, 0  # 总共的真实峰值数,总共预测出来的峰值数

            valid_step_cnt = 0
            total_valid_loss = 0.0
            no_peak = np.zeros(c.scan_num)  # 记录无峰值输出(全0输出)的样本个数
            total = 0  # 验证集的总样本个数
    x = x.squeeze(0)
    x = np.transpose(x, (1,2,0))
    x = normalization(x)
    x = preprocess_image(x, resize_im=False)
    x = recreate_image(x)
    if not os.path.exists('./deconv/'+ png_dir):
        os.makedirs('./deconv/'+ png_dir)
    im_path = './deconv/'+ png_dir+ '/layer_vis_' + str(demode) +'_' + str(index)+ '.jpg'
    save_image(x, im_path)


# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':

    #get model and data
    # Build the network, put it in eval mode, move it to the GPU, then load
    # the trained weights from ./cifar_net.pth.
    net = ResNet()
    net.eval()
    net = net.cuda()
    net.load_state_dict(torch.load('./cifar_net.pth'))
#     print(net.state_dict().keys())
#     for name, module in net._modules.items():
#         for name, module in module._modules.items():
#             print(name)
#     for name, param in net.named_parameters():
#         print(name)
#     net2 = ResNet_deconv(demode=1)
#     net2 = net2.cuda()
#     encorder = ResNet_encorder(demode=2)
#     encorder = encorder.cuda()
#     decorder = ResNet_decorder2(demode=2)
#     decorder = decorder.cuda()
#     params=net.state_dict() 
Exemplo n.º 5
0
    utils.mkdir(args.save_dir)

    # cls and sord
    print("Creating model......")
    if args.model_name == "mobilenetv2":
        model = MobileNetV2(num_classes=args.num_classes)
    else:
        model = ResNet(torchvision.models.resnet50(pretrained=False),
                       args.num_classes)

    print("Loading weight......")
    saved_state_dict = torch.load(args.snapshot)
    model.load_state_dict(saved_state_dict)
    model.cuda(0)

    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).

    softmax = nn.Softmax(dim=1).cuda(0)

    # test dataLoader
    test_loader = loadData(args.test_data, args.input_size, args.batch_size,
                           args.num_classes, False)

    # testing
    print('Start testing......')

    if args.collect_score:
        utils.mkdir(os.path.join(args.save_dir, "collect_score"))
    test(model, test_loader, softmax, args)
def train(img_dir, xml_dir, epochs, input_size, batch_size, num_classes):
    """
    Train a ResNet-50 based classifier and evaluate it after every epoch.

    The model is saved to ``models/model_<epoch>.pt`` whenever the number of
    correctly classified test samples improves.

    params:
          img_dir: training image location, forwarded to loadData
          xml_dir: training annotation location, forwarded to loadData
          epochs: number of training epochs
          input_size: forwarded to loadData
          batch_size: training batch size, forwarded to loadData
          num_classes: number of output classes of the model
    """
    # create model
    model = ResNet(torchvision.models.resnet50(pretrained=True),
                   num_classes=num_classes)

    cls_criterion = nn.CrossEntropyLoss().cuda(1)
    model.cuda(1)

    # initialize learning rate
    lr = 0.001

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    #load data
    train_data_loader = loadData(img_dir, xml_dir, input_size, batch_size,
                                 True)
    test_loader = loadData('../yolov3/data/test_imgs',
                           '../yolov3/data/test_anns', 224, 8, False)

    #variables
    best_acc = 0.0

    # start training
    for epoch in range(epochs):
        print("Epoch:", epoch)
        print("------------")

        # reduce lr by a factor of 0.9 every 10 epochs
        if epoch % 10 == 0:
            lr = lr * 0.9
            # The decayed value must be written into the optimizer; updating
            # the local variable alone leaves the learning rate unchanged.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        train_loss = 0.0
        train_acc = 0
        train_seen = 0  # number of training samples processed this epoch
        val_acc = 0

        model.train()

        for i, (images, labels) in enumerate(train_data_loader):
            if i % 10 == 0:
                print("batch: {}/{}".format(
                    i,
                    len(train_data_loader.dataset) // batch_size))
            images = images.cuda(1)
            labels = labels.cuda(1)

            # backward
            optimizer.zero_grad()
            outputs = model(images)

            loss = cls_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            _, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))

            acc = torch.mean(correct_counts.type(torch.FloatTensor))

            # acc is the batch mean, so scale by batch size to count hits.
            train_acc += acc.item() * images.size(0)
            train_seen += images.size(0)

        # Divide by the samples actually seen rather than a hard-coded
        # dataset size, so the percentage is right for any dataset.
        print("epoch: {:03d}, Training loss: {:.4f}, Accuracy: {:.4f}%".format(
            epoch + 1, train_loss, train_acc / max(train_seen, 1) * 100))

        print("Start testing...")
        with torch.no_grad():
            model.eval()

            for images, labels in test_loader:
                images = images.cuda(1)
                labels = labels.cuda(1)

                outputs = model(images)

                _, preds = torch.max(outputs.data, 1)
                cnt = preds.eq(labels.data.view_as(preds))

                acc = torch.mean(cnt.type(torch.FloatTensor))
                val_acc += acc.item() * images.size(0)

            # val_acc is a count of correct test samples, not a ratio.
            if val_acc > best_acc:
                print("correct testing samples:", val_acc)
                best_acc = val_acc
                torch.save(model,
                           'models/' + 'model_' + str(epoch + 1) + '.pt')
Exemplo n.º 7
0
class ResNetPredictor:
    """Train a ``ResNet`` that emits onset / offset / pitch logits and decode its frame output into notes."""

    def __init__(self, model_path=None):
        """
        Params:
        model_path: Optional pretrained model file (state dict read with ``torch.load``)
        """
        # Initialize model
        self.model = ResNet().cuda()

        if model_path is not None:
            self.model.load_state_dict(torch.load(model_path))
            print('Model read from {}.'.format(model_path))

        print('Predictor initialized.')

    def _batch_loss(self, batch):
        """Forward one (features, labels) batch and return the summed onset + offset + pitch loss."""
        input_tensor = batch[0].permute(0, 2, 1).unsqueeze(1).cuda()
        labels = batch[1]
        # Label columns 0 / 1 / 2 feed the onset / offset / pitch criteria.
        onset_prob = labels[:, 0].float().cuda()
        offset_prob = labels[:, 1].float().cuda()
        pitch_class = labels[:, 2].cuda()

        # Forward model
        onset_logits, offset_logits, pitch_logits = self.model(input_tensor)

        # Calculate loss
        return self.onset_criterion(onset_logits, onset_prob) \
            + self.offset_criterion(offset_logits, offset_prob) \
            + self.pitch_criterion(pitch_logits, pitch_class)

    def fit(self, train_dataset_path, valid_dataset_path, model_dir, **training_args):
        """
        Train the model, validating and checkpointing every ``save_every_epoch`` epochs.

        train_dataset_path: The path to the training dataset.pkl
        valid_dataset_path: The path to the validation dataset.pkl
        model_dir: The directory to save models for each epoch
        training_args (all required; a missing key raises KeyError):
          - batch_size
          - valid_batch_size
          - epoch
          - lr
          - save_every_epoch

        The (epoch, loss) history is pickled to ./plotting/data/loss.pkl.
        """
        # Set paths
        self.train_dataset_path = train_dataset_path
        self.valid_dataset_path = valid_dataset_path
        self.model_dir = model_dir
        Path(self.model_dir).mkdir(parents=True, exist_ok=True)

        # Set training params
        self.batch_size = training_args['batch_size']
        self.valid_batch_size = training_args['valid_batch_size']
        self.epoch = training_args['epoch']
        self.lr = training_args['lr']
        self.save_every_epoch = training_args['save_every_epoch']

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.onset_criterion = nn.BCEWithLogitsLoss()
        self.offset_criterion = nn.BCEWithLogitsLoss()
        self.pitch_criterion = nn.CrossEntropyLoss()

        # Read the datasets
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point these paths at datasets you produced yourself.
        print('Reading datasets...')
        with open(self.train_dataset_path, 'rb') as f:
            self.training_dataset = pickle.load(f)
        with open(self.valid_dataset_path, 'rb') as f:
            self.validation_dataset = pickle.load(f)

        # Setup dataloader and initial variables
        self.train_loader = DataLoader(
            self.training_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True,
            shuffle=True,
            drop_last=True,
        )
        self.valid_loader = DataLoader(
            self.validation_dataset,
            batch_size=self.valid_batch_size,
            num_workers=4,
            pin_memory=True,
            shuffle=False,
            drop_last=False,
        )

        start_time = time.time()
        training_loss_list = []
        valid_loss_list = []

        # Start training
        self.iters_per_epoch = len(self.train_loader)
        for epoch in range(1, self.epoch + 1):
            self.model.train()

            # Run iterations
            total_training_loss = 0
            for batch in self.train_loader:
                self.optimizer.zero_grad()

                loss = self._batch_loss(batch)
                loss.backward()
                self.optimizer.step()

                total_training_loss += loss.item()

            if epoch % self.save_every_epoch == 0:
                # Perform validation
                self.model.eval()
                with torch.no_grad():
                    total_valid_loss = 0
                    for batch in self.valid_loader:
                        total_valid_loss += self._batch_loss(batch).item()

                # Save model
                save_dict = self.model.state_dict()
                target_model_path = Path(self.model_dir) / 'e_{}'.format(epoch)
                torch.save(save_dict, target_model_path)

                # Save loss list (per-batch averages)
                training_loss_list.append((epoch, total_training_loss/self.iters_per_epoch))
                valid_loss_list.append((epoch, total_valid_loss/len(self.valid_loader)))

                # Epoch statistics
                print(
                    '| Epoch [{:4d}/{:4d}] Train Loss {:.4f} Valid Loss {:.4f} Time {:.1f}'.format(
                        epoch,
                        self.epoch,
                        training_loss_list[-1][1],
                        valid_loss_list[-1][1],
                        time.time()-start_time,
                    )
                )

        # Save loss to file. Create the directory first: failing here would
        # throw away the loss history after the whole training run.
        loss_path = Path('./plotting/data/loss.pkl')
        loss_path.parent.mkdir(parents=True, exist_ok=True)
        with open(loss_path, 'wb') as f:
            pickle.dump({'train': training_loss_list, 'valid': valid_loss_list}, f)

        print('Training done in {:.1f} minutes.'.format((time.time()-start_time)/60))

    @staticmethod
    def _append_note(result, onset_time, offset_time, pitch_counter):
        """Append [onset, offset, pitch] to *result* using the most frequent usable pitch class.

        Class 49 is never emitted (presumably the "no pitch" class -- confirm);
        when it is the most common class the runner-up is used instead, and the
        note is dropped if there is no runner-up. 36 is added to the class index
        (presumably mapping class 0 to MIDI note 36 -- confirm).
        """
        top_pitch = pitch_counter.most_common(1)[0][0]
        if top_pitch != 49:
            result.append([onset_time, offset_time, top_pitch + 36])
            return
        ranked = pitch_counter.most_common(2)
        if len(ranked) == 2:
            result.append([onset_time, offset_time, ranked[1][0] + 36])

    def _parse_frame_info(self, frame_info):
        """Parse frame info [(onset_probs, offset_probs, pitch_class)...] into desired label format.

        Returns a list of [onset_time, offset_time, pitch] entries. Frame
        ``idx`` is timestamped at the centre of its FRAME_LENGTH-wide window.
        """
        onset_thres = 0.25
        offset_thres = 0.25

        result = []
        current_onset = None
        pitch_counter = Counter()
        for idx, info in enumerate(frame_info):
            current_time = FRAME_LENGTH*idx + FRAME_LENGTH/2

            if info[0] >= onset_thres:  # If is onset
                if current_onset is not None:
                    # A note is already open: this onset closes it and
                    # immediately starts the next one.
                    self._append_note(result, current_onset, current_time, pitch_counter)
                    pitch_counter.clear()
                current_onset = current_time
            elif info[1] >= offset_thres:  # If is offset
                if current_onset is not None:
                    self._append_note(result, current_onset, current_time, pitch_counter)
                    current_onset = None
                    pitch_counter.clear()

            # If current_onset exist, add count for the pitch
            if current_onset is not None:
                pitch_counter[info[2]] += 1

        return result

    def predict(self, test_dataset):
        """Predict results for a given test dataset.

        Returns a dict mapping each song id to the note list produced by
        ``_parse_frame_info`` from that song's frames (in loader order).
        """
        # Setup params and dataloader
        batch_size = 500
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            pin_memory=False,
            shuffle=False,
            drop_last=False,
        )

        # Start predicting
        self.model.eval()
        with torch.no_grad():
            print('Forwarding model...')
            song_frames_table = {}
            for batch in tqdm(test_loader):
                # Parse batch data
                input_tensor = batch[0].unsqueeze(1).cuda()
                song_ids = batch[1]

                # Forward model
                onset_logits, offset_logits, pitch_logits = self.model(input_tensor)
                onset_probs = torch.sigmoid(onset_logits).cpu()
                offset_probs = torch.sigmoid(offset_logits).cpu()
                pitch_logits = pitch_logits.cpu()

                # Collect frames for corresponding songs
                for bid, song_id in enumerate(song_ids):
                    frame_info = (onset_probs[bid], offset_probs[bid], torch.argmax(pitch_logits[bid]).item())
                    song_frames_table.setdefault(song_id, []).append(frame_info)

            # Parse frame info into output format for every song
            results = {}
            for song_id, frame_info in song_frames_table.items():
                results[song_id] = self._parse_frame_info(frame_info)

        return results