Ejemplo n.º 1
0
def train_and_valid(learning_rate=lr,
                    weight_decay=weight_decay,
                    num_of_res=num_of_res_blocks,
                    if_bottleneck=bottle,
                    plot=True):
    """
    Train the model and run it on the valid set every epoch.

    Train/valid files are loaded one at a time and concatenated into chunks
    holding at least ``batch_size`` samples (see ``_iter_hoa_file_chunks``);
    each chunk is consumed through a shuffled ``DataLoader``. The parameters
    from the epoch with the lowest average validation loss are saved to a
    checkpoint under ``./models/``.

    :param learning_rate: initial learning rate for Adam
    :param weight_decay: Adam weight decay (L2 regularization)
    :param num_of_res: number of residual blocks (used when NET_TYPE == 'res')
    :param if_bottleneck: passed to ResBlock as ``bottleneck`` (NET_TYPE == 'res')
    :param plot: if True, save the train/valid loss curve to ``name + '.jpg'``
    :return: None
    """
    import copy  # local import: used to snapshot the best weights

    curr_lr = learning_rate
    model = _build_train_model(num_of_res, if_bottleneck)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)

    # Bookkeeping used to pick the best epoch (early-stopping style selection).
    train_loss_curve, valid_loss_curve = [], []
    best_loss = 10000.0
    best_epoch = 0
    best_model = None  # deep copy of the best parameters

    for epoch in range(num_epochs):
        # Per-epoch train/valid loss sums and step counts.
        train_loss_per_epoch, valid_loss_per_epoch = 0.0, 0.0
        train_step_cnt, valid_step_cnt = 0, 0

        # ---- training pass ----
        model.train()
        random.shuffle(train_file_order)
        for train_data in _iter_hoa_file_chunks('tr/', train_file_order):
            train_loader = data.DataLoader(dataset=train_data,
                                           batch_size=batch_size,
                                           shuffle=True)
            for examples, labels in train_loader:
                train_step_cnt += 1
                examples = examples.float().to(device)
                labels = labels.float().to(device)
                outputs = model(examples)
                train_loss = criterion(outputs, labels)
                train_loss_per_epoch += train_loss.item()

                # Backward and optimize
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                logger.info(
                    "Epoch [{}/{}], Step {}, train Loss: {:.4f}".format(
                        epoch + 1, num_epochs, train_step_cnt,
                        train_loss.item()))

        if plot:
            # NaN (instead of ZeroDivisionError) when no training step ran.
            train_loss_curve.append(
                train_loss_per_epoch / train_step_cnt
                if train_step_cnt else float('nan'))

        if running_lr and epoch > 1 and (epoch + 1) % 2 == 0:
            curr_lr = curr_lr * (1 - decay)
            update_lr(optimizer, curr_lr)

        # ---- validation pass (every epoch) ----
        model.eval()
        with torch.no_grad():
            for valid_data in _iter_hoa_file_chunks('cv/', valid_file_order):
                valid_loader = data.DataLoader(dataset=valid_data,
                                               batch_size=batch_size,
                                               shuffle=True)
                for examples, labels in valid_loader:
                    valid_step_cnt += 1
                    examples = examples.float().to(device)
                    labels = labels.float().to(device)

                    outputs = model(examples)
                    valid_loss = criterion(outputs, labels)
                    valid_loss_per_epoch += valid_loss.item()

                    logger.info(
                        'The loss for the current batch:{}'.format(valid_loss))

        avr_valid_loss = (valid_loss_per_epoch / valid_step_cnt
                          if valid_step_cnt else float('nan'))

        logger.info(
            'Epoch {}, the average loss on the valid set: {} '.format(
                epoch, avr_valid_loss))

        valid_loss_curve.append(avr_valid_loss)
        if avr_valid_loss < best_loss:
            best_loss = avr_valid_loss
            best_epoch = epoch
            # state_dict() returns references to the live parameter tensors;
            # without a copy, later epochs would overwrite the "best" weights
            # and the checkpoint would hold the final model instead.
            best_model = copy.deepcopy(model.state_dict())

    # end for loop of epoch
    torch.save(
        {
            'epoch': best_epoch,
            'state_dict': best_model,
            'loss': best_loss,
        }, './models/ckpoint_' + CUR_TASK + '_bot' + str(int(if_bottleneck)) +
        '_lr' + str(learning_rate) + '_wd' + str(weight_decay) + '_#res' +
        str(num_of_res) + '.tar')

    logger.info('best epoch:{}, valid loss:{}'.format(best_epoch, best_loss))
    if plot:
        x = np.arange(num_epochs)
        fig, ax = plt.subplots(1, 1)
        ax.plot(x, train_loss_curve, 'b', label='Train Loss')
        ax.plot(x, valid_loss_curve, 'r', label='Valid Loss')
        plt.legend(loc='upper right')
        plt.savefig(name + '.jpg')
        plt.close()


def _build_train_model(num_of_res, if_bottleneck):
    """Instantiate the network selected by NET_TYPE and move it to ``device``.

    ResBlock is built with (128, 128) for DATA_TYPE == 'hoa' and (256, 256)
    otherwise.

    :raises RuntimeError: if NET_TYPE is neither 'res' nor 'hoa'
    """
    if NET_TYPE == 'res':
        width = 128 if DATA_TYPE == 'hoa' else 256
        block = ResBlock(width, width, bottleneck=if_bottleneck)
        return ResNet(block,
                      numOfResBlock=num_of_res,
                      input_shape=input_shape,
                      data_type=DATA_TYPE).to(device)
    if NET_TYPE == 'hoa':
        return HOANet(input_shape=input_shape).to(device)
    raise RuntimeError('Unrecognized net type!')


def _iter_hoa_file_chunks(split, file_order):
    """Yield concatenated HOADataSet chunks for one data split.

    Files are loaded in ``file_order`` (HOADataSet receives ``index + 1``) from
    ``DATA_PATH + ('' or 'STFT/') + split`` and concatenated until the chunk
    holds at least ``batch_size`` samples; the chunk is then yielded and a new
    one is started. The trailing, smaller chunk is yielded too, so the last
    files of a split are never silently skipped.

    :param split: sub-directory including the trailing slash ('tr/' or 'cv/')
    :param file_order: iterable of file indices, in loading order
    """
    path = DATA_PATH + ('' if DATA_TYPE == 'hoa' else 'STFT/') + split
    chunk = None
    for file_idx in file_order:
        dataset = HOADataSet(path=path,
                             index=file_idx + 1,
                             data_type=DATA_TYPE,
                             is_speech=SPEECH)
        # Same '+' concatenation the original '+=' performed (presumably a
        # torch Dataset, so this builds a ConcatDataset).
        chunk = dataset if chunk is None else chunk + dataset
        if len(chunk) >= batch_size:
            yield chunk
            chunk = None
    if chunk is not None and len(chunk) > 0:
        yield chunk
Ejemplo n.º 2
0
def train(img_dir, xml_dir, epochs, input_size, batch_size, num_classes):
    """
    Fine-tune a ResNet-50 based classifier and save the best model.

    The backbone is ``torchvision.models.resnet50(pretrained=True)`` wrapped
    by the project's ``ResNet`` class. After every epoch the model is evaluated
    on a fixed test set; whenever the number of correctly classified test
    samples improves, the whole model is saved to
    ``models/model_<epoch + 1>.pt``.

    params:
          img_dir: image directory passed to ``loadData`` for training
          xml_dir: annotation directory passed to ``loadData`` for training
          epochs: number of training epochs
          input_size: input size passed to ``loadData`` for training
          batch_size: training batch size
          num_classes: number of output classes of the classifier
    """
    import os  # local import: only needed to create the checkpoint directory

    gpu = 1  # every module and tensor below is placed on CUDA device 1

    # create model
    model = ResNet(torchvision.models.resnet50(pretrained=True),
                   num_classes=num_classes)
    model.cuda(gpu)
    cls_criterion = nn.CrossEntropyLoss().cuda(gpu)

    # Learning-rate schedule: start at 1e-3 and multiply by 0.9 every 10 epochs.
    lr = 0.001
    lr_decay = 0.9
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # load data
    train_data_loader = loadData(img_dir, xml_dir, input_size, batch_size,
                                 True)
    test_loader = loadData('../yolov3/data/test_imgs',
                           '../yolov3/data/test_anns', 224, 8, False)
    num_train = len(train_data_loader.dataset)

    os.makedirs('models', exist_ok=True)
    best_acc = 0.0

    # start training
    for epoch in range(epochs):
        print("Epoch:", epoch)
        print("------------")

        # Previously the decayed value was computed but never given to the
        # optimizer, so the schedule had no effect.
        if epoch > 0 and epoch % 10 == 0:
            lr *= lr_decay
            for group in optimizer.param_groups:
                group['lr'] = lr

        train_loss = 0.0
        train_correct = 0

        model.train()
        for i, (images, labels) in enumerate(train_data_loader):
            if i % 10 == 0:
                print("batch: {}/{}".format(i, num_train // batch_size))
            images = images.cuda(gpu)
            labels = labels.cuda(gpu)

            # forward / backward
            optimizer.zero_grad()
            outputs = model(images)
            loss = cls_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()

        # Accuracy over the actual training-set size (was a hard-coded 3096).
        print("epoch: {:03d}, Training loss: {:.4f}, Accuracy: {:.4f}%".format(
            epoch + 1, train_loss, train_correct / max(num_train, 1) * 100))

        print("Start testing...")
        model.eval()
        val_acc = 0  # number of correctly classified test samples
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.cuda(gpu)
                labels = labels.cuda(gpu)
                outputs = model(images)
                val_acc += (outputs.argmax(dim=1) == labels).sum().item()

        if val_acc > best_acc:
            print("correct testing samples:", val_acc)
            best_acc = val_acc
            torch.save(model, 'models/' + 'model_' + str(epoch + 1) + '.pt')
Ejemplo n.º 3
0
def train(net, bins, alpha, beta, batch_size):
    """
    Train a ResNet-50 or MobileNetV2 model and keep the best validation snapshot.

    Each epoch builds an Adam optimizer with per-group learning rates, trains
    with ``utils.computeLoss`` (its three losses are back-propagated together),
    evaluates with ``valid`` and saves ``model.state_dict()`` into
    ``snapshot_dir`` whenever the summed validation error improves.

    params:
          net: "resnet50" selects ResNet-50; any other value selects MobileNetV2
          bins: number of bins for classification
          alpha: regression loss weight
          beta: ortho loss weight
          batch_size: batch size of the training and validation loaders

    raises:
          ValueError: if ``args.cls_loss`` is not KLDiv, BCE, FocalLoss or
                      CrossEntropy (previously this surfaced later as a
                      NameError on ``cls_criterion``)
    """
    cls_loss_factories = {
        "KLDiv": lambda: nn.KLDivLoss(reduction='batchmean'),
        "BCE": nn.BCELoss,
        "FocalLoss": lambda: FocalLoss(bins),
        "CrossEntropy": nn.CrossEntropyLoss,
    }
    if args.cls_loss not in cls_loss_factories:
        raise ValueError("Unsupported cls_loss: {!r}".format(args.cls_loss))

    # create model
    if net == "resnet50":
        # NOTE(review): ResNet-50 starts from random weights while MobileNetV2
        # uses pretrained ones — confirm this asymmetry is intended.
        model = ResNet(torchvision.models.resnet50(pretrained=False),
                       num_classes=bins)
        lr = args.lr_resnet
        default_lr = args.lr_resnet
    else:
        model = MobileNetV2(torchvision.models.mobilenet_v2(pretrained=True),
                            num_classes=bins)
        lr = args.lr_mobilenet
        default_lr = args.lr_mobilenet

    # loading data
    logger.logger.info("Loading data".center(100, '='))
    train_data_loader = loadData(args.train_data, args.input_size, batch_size,
                                 bins)
    valid_data_loader = loadData(args.valid_data, args.input_size, batch_size,
                                 bins, False)

    # loss functions and activations, all on GPU 0
    cls_criterion = cls_loss_factories[args.cls_loss]().cuda(0)
    reg_criterion = nn.MSELoss().cuda(0)
    softmax = nn.Softmax(dim=1).cuda(0)
    sigmoid = nn.Sigmoid().cuda(0)
    model.cuda(0)

    # training log
    logger.logger.info("Training".center(100, '='))

    # Unit gradients for the three losses returned by utils.computeLoss;
    # allocated once instead of on every iteration.
    loss_grads = [torch.tensor(1.0).cuda(0) for _ in range(3)]

    # best (lowest) sum of the three validation errors so far
    min_avg_error = 1000.

    # start training
    for epoch in range(args.epochs):
        print("Epoch:", epoch)
        model.train()

        # Per-group learning rates. For ResNet-50 both unfreeze phases were
        # identical (backbone lr, heads 10x lr), so args.unfreeze has no effect
        # there. MobileNetV2 uses 10x lr for both groups before args.unfreeze.
        if net == 'resnet50':
            backbone_lr, head_lr = lr, lr * 10
        elif epoch >= args.unfreeze:
            backbone_lr, head_lr = lr, lr
        else:
            backbone_lr, head_lr = lr * 10, lr * 10

        # NOTE(review): a new Adam optimizer is built every epoch, which
        # discards its moment estimates; kept to preserve existing behavior.
        optimizer = torch.optim.Adam(
            [{
                "params": get_non_ignored_params(model, net),
                "lr": backbone_lr
            }, {
                "params": get_cls_fc_params(model),
                "lr": head_lr
            }],
            lr=default_lr)

        # reduce lr by lr_decay factor for each epoch
        lr = lr * args.lr_decay
        print("------------")

        for i, (images, cls_v1, cls_v2, cls_v3, reg_v1, reg_v2, reg_v3, _name,
                left_targets, down_targets,
                front_targets) in enumerate(train_data_loader):
            images = images.cuda(0).float()

            # classified labels
            cls_v1 = cls_v1.cuda(0)
            cls_v2 = cls_v2.cuda(0)
            cls_v3 = cls_v3.cuda(0)

            # continuous labels
            reg_v1 = reg_v1.cuda(0)
            reg_v2 = reg_v2.cuda(0)
            reg_v3 = reg_v3.cuda(0)

            left_targets = left_targets.cuda(0)
            down_targets = down_targets.cuda(0)
            front_targets = front_targets.cuda(0)

            # inference: x/y/z predictions for v1, v2 and v3
            (x_pred_v1, y_pred_v1, z_pred_v1,
             x_pred_v2, y_pred_v2, z_pred_v2,
             x_pred_v3, y_pred_v3, z_pred_v3) = model(images)
            logits = [
                x_pred_v1, y_pred_v1, z_pred_v1, x_pred_v2, y_pred_v2,
                z_pred_v2, x_pred_v3, y_pred_v3, z_pred_v3
            ]

            loss, degree_error_v1, degree_error_v2, degree_error_v3 = utils.computeLoss(
                cls_v1, cls_v2, cls_v3, reg_v1, reg_v2, reg_v3, logits,
                softmax, sigmoid, cls_criterion, reg_criterion, left_targets,
                down_targets, front_targets, [
                    bins, alpha, beta, args.cls_loss, args.reg_loss,
                    args.ortho_loss
                ])

            # backward
            optimizer.zero_grad()
            torch.autograd.backward(loss, loss_grads)
            optimizer.step()

            # training log
            if (i + 1) % 500 == 0:
                msg = "Epoch: %d/%d | Iter: %d/%d | x_loss: %.6f | y_loss: %.6f | z_loss: %.6f | degree_error_f:%.3f | degree_error_r:%.3f | degree_error_u:%.3f" % (
                    epoch, args.epochs, i + 1, len(train_data_loader.dataset)
                    // batch_size, loss[0].item(), loss[1].item(),
                    loss[2].item(), degree_error_v1.item(),
                    degree_error_v2.item(), degree_error_v3.item())
                print(msg)
                logger.logger.info(msg)

        # Test on validation dataset
        error_v1, error_v2, error_v3 = valid(model, valid_data_loader, softmax,
                                             bins)
        errors = [err.item() for err in (error_v1, error_v2, error_v3)]
        print("Epoch:", epoch)
        print("Validation Error:", *errors)
        logger.logger.info("Validation Error(l,d,f)_{},{},{}".format(*errors))

        # save model if achieve better validation performance
        total_error = sum(errors)
        if total_error < min_avg_error:
            min_avg_error = total_error
            print("Training Info:")
            print("Model:", net, " ", "Number of bins:", bins, " ", "Alpha:",
                  alpha, " ", "Beta:", beta)
            print("Saving Model......")
            torch.save(
                model.state_dict(),
                os.path.join(snapshot_dir, output_string + '_Best_' + '.pkl'))
            print("Saved")
Ejemplo n.º 4
0
class ResNetPredictor:
    """Train a ``ResNet`` onset/offset/pitch model and turn its frame outputs into notes."""

    def __init__(self, model_path=None):
        """
        Params:
        model_path: Optional pretrained model file (a state_dict loaded with
                    ``torch.load`` and ``load_state_dict``)
        """
        # Initialize model
        self.model = ResNet().cuda()

        if model_path is not None:
            self.model.load_state_dict(torch.load(model_path))
            print('Model read from {}.'.format(model_path))

        print('Predictor initialized.')

    def fit(self, train_dataset_path, valid_dataset_path, model_dir, **training_args):
        """
        train_dataset_path: The path to the training dataset.pkl
        valid_dataset_path: The path to the validation dataset.pkl
        model_dir: The directory to save models for each epoch
        training_args:
          - batch_size
          - valid_batch_size
          - epoch
          - lr
          - save_every_epoch

        Every ``save_every_epoch`` epochs the model is validated, its
        state_dict is saved to ``model_dir/e_<epoch>`` and the average
        train/valid losses are recorded. At the end, the loss history is
        pickled to ``./plotting/data/loss.pkl``.
        """
        # Set paths
        self.train_dataset_path = train_dataset_path
        self.valid_dataset_path = valid_dataset_path
        self.model_dir = model_dir
        Path(self.model_dir).mkdir(parents=True, exist_ok=True)

        # Set training params
        self.batch_size = training_args['batch_size']
        self.valid_batch_size = training_args['valid_batch_size']
        self.epoch = training_args['epoch']
        self.lr = training_args['lr']
        self.save_every_epoch = training_args['save_every_epoch']

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.onset_criterion = nn.BCEWithLogitsLoss()
        self.offset_criterion = nn.BCEWithLogitsLoss()
        self.pitch_criterion = nn.CrossEntropyLoss()

        # Read the datasets.
        # SECURITY: pickle.load can execute arbitrary code; only load dataset
        # files from a trusted source.
        print('Reading datasets...')
        with open(self.train_dataset_path, 'rb') as f:
            self.training_dataset = pickle.load(f)
        with open(self.valid_dataset_path, 'rb') as f:
            self.validation_dataset = pickle.load(f)

        # Setup dataloader and initial variables
        self.train_loader = DataLoader(
            self.training_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True,
            shuffle=True,
            drop_last=True,
        )
        self.valid_loader = DataLoader(
            self.validation_dataset,
            batch_size=self.valid_batch_size,
            num_workers=4,
            pin_memory=True,
            shuffle=False,
            drop_last=False,
        )

        start_time = time.time()
        training_loss_list = []
        valid_loss_list = []

        # Start training. With drop_last=True this can be 0 when the training
        # set is smaller than one batch, so averages below guard against it.
        self.iters_per_epoch = len(self.train_loader)
        for epoch in range(1, self.epoch + 1):
            self.model.train()

            total_training_loss = 0
            for batch in self.train_loader:
                self.optimizer.zero_grad()
                loss = self._batch_loss(batch)
                loss.backward()
                self.optimizer.step()
                total_training_loss += loss.item()

            if epoch % self.save_every_epoch == 0:
                # Perform validation
                self.model.eval()
                total_valid_loss = 0
                with torch.no_grad():
                    for batch in self.valid_loader:
                        total_valid_loss += self._batch_loss(batch).item()

                # Save model
                target_model_path = Path(self.model_dir) / 'e_{}'.format(epoch)
                torch.save(self.model.state_dict(), target_model_path)

                # Save loss list
                training_loss_list.append(
                    (epoch, total_training_loss / max(self.iters_per_epoch, 1)))
                valid_loss_list.append(
                    (epoch, total_valid_loss / max(len(self.valid_loader), 1)))

                # Epoch statistics
                print(
                    '| Epoch [{:4d}/{:4d}] Train Loss {:.4f} Valid Loss {:.4f} Time {:.1f}'.format(
                        epoch,
                        self.epoch,
                        training_loss_list[-1][1],
                        valid_loss_list[-1][1],
                        time.time()-start_time,
                    )
                )

        # Save loss to file; create the directory so a finished training run
        # does not crash here when it is missing.
        loss_path = Path('./plotting/data/loss.pkl')
        loss_path.parent.mkdir(parents=True, exist_ok=True)
        with open(loss_path, 'wb') as f:
            pickle.dump({'train': training_loss_list, 'valid': valid_loss_list}, f)

        print('Training done in {:.1f} minutes.'.format((time.time()-start_time)/60))

    def _batch_loss(self, batch):
        """Return onset + offset + pitch loss for one training-format batch.

        ``batch[0]`` is permuted (0, 2, 1) and given a channel axis at dim 1
        before the forward pass. Columns 0, 1 and 2 of ``batch[1]`` hold the
        onset target, offset target and pitch class.
        """
        input_tensor = batch[0].permute(0, 2, 1).unsqueeze(1).cuda()
        labels = batch[1]
        onset_prob = labels[:, 0].float().cuda()
        offset_prob = labels[:, 1].float().cuda()
        pitch_class = labels[:, 2].cuda()

        onset_logits, offset_logits, pitch_logits = self.model(input_tensor)

        return (self.onset_criterion(onset_logits, onset_prob)
                + self.offset_criterion(offset_logits, offset_prob)
                + self.pitch_criterion(pitch_logits, pitch_class))

    @staticmethod
    def _note_pitch(pitch_counter):
        """Return the pitch class to report for a finished note, or None to drop it.

        The most frequent pitch class wins unless it is 49, in which case the
        runner-up is used; a note whose only pitch class is 49 (or that has no
        counted frames) is dropped.
        NOTE(review): 49 looks like a "no pitch" class — confirm.
        """
        ranked = pitch_counter.most_common(2)
        if not ranked:
            return None
        if ranked[0][0] != 49:
            return ranked[0][0]
        if len(ranked) == 2:
            return ranked[1][0]
        return None

    def _parse_frame_info(self, frame_info):
        """Parse frame info [(onset_probs, offset_probs, pitch_class)...] into desired label format.

        Frame ``idx`` is stamped at ``FRAME_LENGTH * idx + FRAME_LENGTH / 2``.
        A frame with onset prob >= 0.25 opens a note, closing any open note at
        that time. Otherwise, a frame with offset prob >= 0.25 closes the open
        note. Each emitted item is ``[onset_time, offset_time, pitch + 36]``.
        A note still open after the last frame is not emitted.
        """
        onset_thres = 0.25
        offset_thres = 0.25

        result = []
        current_onset = None
        pitch_counter = Counter()  # pitch-class votes for the open note

        def close_note(end_time):
            pitch = self._note_pitch(pitch_counter)
            if pitch is not None:
                # NOTE(review): +36 presumably maps the class index to a MIDI
                # note number — confirm.
                result.append([current_onset, end_time, pitch + 36])

        for idx, info in enumerate(frame_info):
            onset_p, offset_p, pitch = info[0], info[1], info[2]
            current_time = FRAME_LENGTH*idx + FRAME_LENGTH/2

            if onset_p >= onset_thres:
                # A new onset ends the open note (if any) and starts the next.
                # The counter is already empty when no note is open.
                if current_onset is not None:
                    close_note(current_time)
                current_onset = current_time
                pitch_counter.clear()
            elif offset_p >= offset_thres and current_onset is not None:
                close_note(current_time)
                current_onset = None
                pitch_counter.clear()

            # While a note is open, count this frame's pitch class.
            if current_onset is not None:
                pitch_counter[pitch] += 1

        return result

    def predict(self, test_dataset):
        """Predict results for a given test dataset.

        Returns a dict mapping each song id (``batch[1]`` entries) to the
        note list produced by ``_parse_frame_info`` for that song's frames,
        in loader order.
        """
        # Setup params and dataloader
        batch_size = 500
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            pin_memory=False,
            shuffle=False,
            drop_last=False,
        )

        self.model.eval()
        song_frames_table = {}
        with torch.no_grad():
            print('Forwarding model...')
            for batch in tqdm(test_loader):
                # NOTE(review): unlike fit(), batch[0] is not permuted (0, 2, 1)
                # here — confirm test_dataset already yields that layout.
                input_tensor = batch[0].unsqueeze(1).cuda()
                song_ids = batch[1]

                # Forward model
                onset_logits, offset_logits, pitch_logits = self.model(input_tensor)
                onset_probs = torch.sigmoid(onset_logits).cpu()
                offset_probs = torch.sigmoid(offset_logits).cpu()
                pitch_logits = pitch_logits.cpu()

                # Collect frames for corresponding songs.
                # NOTE(review): if song ids are ints, default collation yields
                # tensor elements, which hash by identity — confirm ids are strings.
                for bid, song_id in enumerate(song_ids):
                    frame_info = (onset_probs[bid], offset_probs[bid],
                                  torch.argmax(pitch_logits[bid]).item())
                    song_frames_table.setdefault(song_id, []).append(frame_info)

        # Parse frame info into output format for every song
        return {song_id: self._parse_frame_info(frames)
                for song_id, frames in song_frames_table.items()}