Example #1
0
def validate(val_loader, model, args, test=False):
    """Evaluate the CAD-120 sub-activity model on ``val_loader``.

    For every batch this runs ``model``, decodes its output with
    ``inference``, saves one segmentation plot per sequence, and scores three
    outputs against the ground truth:

    * the baseline frame-wise labels (``pred_labels``),
    * the Earley-parser detections (``batch_earley_pred_labels``),
    * the segment / frame predictions produced by ``predict``.

    Per-batch micro-averaged precision (element 0 of sklearn's
    ``precision_recall_fscore_support`` result) is tracked in running meters.
    After the loop the running averages and whole-set macro
    precision/recall/F1 are printed, and normalized confusion matrices are
    written under ``<args.tmp_root>/visualize/confusion/cad/``.

    Args:
        val_loader: iterable of ``(features, labels, probs, total_lengths,
            ctc_labels, ctc_lengths, activities, sequence_ids)`` batches.
        model: network called as ``model(features)``; put into eval mode.
        args: options namespace; this body reads ``args.cuda`` and
            ``args.tmp_root`` and forwards ``args`` to ``inference``.
        test: unused here; kept for interface compatibility.

    Returns:
        ``1.0 -`` the running average precision of the Earley detections
        (lower is better).
    """
    num_classes = len(datasets.cad_metadata.subactivities)

    def compute_accuracy(gt_results, results, metric='micro'):
        # NOTE(review): scores label ids 0..9 only, whereas the confusion
        # matrices below use every entry of cad_metadata.subactivities —
        # confirm both describe the same 10 classes.
        return sklearn.metrics.precision_recall_fscore_support(
            gt_results, results, labels=range(10), average=metric)

    batch_time = logutil.AverageMeter()
    baseline_acc_ratio = logutil.AverageMeter()
    subactivity_acc_ratio = logutil.AverageMeter()
    seg_pred_acc_ratio = logutil.AverageMeter()
    frame_pred_acc_ratio = logutil.AverageMeter()

    # Whole-set accumulators for the macro scores and confusion matrices.
    all_baseline_detections = list()
    all_gt_detections = list()
    all_detections = list()
    all_gt_seg_predictions = list()
    all_gt_frame_predictions = list()
    all_seg_predictions = list()
    all_frame_predictions = list()

    # switch to evaluate mode
    model.eval()

    end_time = time.time()
    for i, (features, labels, probs, total_lengths, ctc_labels, ctc_lengths,
            activities, sequence_ids) in enumerate(val_loader):
        features = utils.to_variable(features, args.cuda)
        labels = utils.to_variable(labels, args.cuda)

        total_lengths = torch.autograd.Variable(total_lengths)

        # Inference
        model_outputs = model(features)
        pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos = inference(
            model_outputs, activities, sequence_ids, ctc_labels, args)

        # Visualize results: one plot per sequence (dimension 1 of ``labels``).
        for batch_i in range(labels.size()[1]):
            vizutil.plot_segmentation(
                [
                    labels[:, batch_i].squeeze(),
                    pred_labels[:, batch_i].squeeze(),
                    batch_earley_pred_labels[batch_i]
                ],
                int(total_lengths[batch_i]),
                filename=os.path.join(
                    args.tmp_root, 'visualize', 'segmentation', 'cad',
                    '{}_{}.pdf'.format(activities[batch_i],
                                       sequence_ids[batch_i])),
                border=False,
                vmax=num_classes)

        # Evaluation
        # Frame-wise detection
        baseline_detections = pred_labels.cpu().data.numpy().flatten().tolist()
        gt_detections = labels.cpu().data.numpy().flatten().tolist()
        # Use a distinct loop name: under Python 2 a comprehension variable
        # leaks into the enclosing scope and would clobber ``pred_labels``.
        detections = [
            label for seq_labels in batch_earley_pred_labels
            for label in seq_labels.tolist()
        ]
        all_baseline_detections.extend(baseline_detections)
        all_gt_detections.extend(gt_detections)
        all_detections.extend(detections)
        baseline_micro_result = compute_accuracy(gt_detections,
                                                 baseline_detections)
        subact_micro_result = compute_accuracy(gt_detections, detections)

        # Segment- and frame-level prediction
        gt_seg_predictions, gt_frame_predictions, seg_predictions, frame_predictions = predict(
            activities, total_lengths, labels, ctc_labels, batch_tokens,
            batch_seg_pos)
        all_gt_seg_predictions.extend(gt_seg_predictions)
        all_gt_frame_predictions.extend(gt_frame_predictions)
        all_seg_predictions.extend(seg_predictions)
        all_frame_predictions.extend(frame_predictions)
        seg_pred_result = compute_accuracy(gt_seg_predictions, seg_predictions)
        frame_pred_result = compute_accuracy(gt_frame_predictions,
                                             frame_predictions)

        # Weight each batch's precision by this batch's own size so that the
        # running averages are per-item means, not dominated by late batches.
        num_frames = torch.sum(total_lengths).data[0]
        baseline_acc_ratio.update(baseline_micro_result[0], num_frames)
        subactivity_acc_ratio.update(subact_micro_result[0], num_frames)
        seg_pred_acc_ratio.update(seg_pred_result[0], num_frames)
        frame_pred_acc_ratio.update(frame_pred_result[0],
                                    len(gt_frame_predictions))

        # Measure elapsed time
        batch_time.update(time.time() - end_time)
        end_time = time.time()

    print(' * Baseline Accuracy Ratio {base_acc.avg:.3f}; '.format(
        base_acc=baseline_acc_ratio))
    print(
        ' * Detection Accuracy Ratio {act_acc.avg:.3f}; Segment Prediction Accuracy Ratio Batch Avg {seg_pred_acc.avg:.3f}; Frame Prediction Accuracy Ratio Batch Avg {frame_pred_acc.avg:.3f}; Time {b_time.avg:.3f}'
        .format(act_acc=subactivity_acc_ratio,
                seg_pred_acc=seg_pred_acc_ratio,
                frame_pred_acc=frame_pred_acc_ratio,
                b_time=batch_time))
    # Whole-set macro precision / recall / F1 / support for each output kind.
    print(
        compute_accuracy(all_gt_detections,
                         all_baseline_detections,
                         metric='macro'))
    print(compute_accuracy(all_gt_detections, all_detections, metric='macro'))
    print(
        compute_accuracy(all_gt_seg_predictions,
                         all_seg_predictions,
                         metric='macro'))
    print(
        compute_accuracy(all_gt_frame_predictions,
                         all_frame_predictions,
                         metric='macro'))

    confusion_dir = os.path.join(args.tmp_root, 'visualize', 'confusion',
                                 'cad')
    for gt_results, results, pdf_name in (
            (all_gt_detections, all_detections, 'detection.pdf'),
            (all_gt_frame_predictions, all_frame_predictions,
             'prediction_frame.pdf'),
            (all_gt_seg_predictions, all_seg_predictions,
             'prediction_seg.pdf')):
        confusion_matrix = sklearn.metrics.confusion_matrix(
            gt_results, results, labels=range(num_classes))
        vizutil.plot_confusion_matrix(confusion_matrix,
                                      datasets.cad_metadata.subactivities[:],
                                      normalize=True,
                                      title='',
                                      filename=os.path.join(confusion_dir,
                                                            pdf_name))

    return 1.0 - subactivity_acc_ratio.avg
def validate(args, val_loader, model, mse_loss, multi_label_loss, vcocoeval, logger=None, test=False):
    """Evaluate the V-COCO HOI model on ``val_loader``.

    Each batch is moved to the device, run through ``model`` and ``loss_fn``,
    and its per-image results are collected with ``append_results``.
    Detections feed a running mAP (``compute_mean_avg_prec``) that is printed
    every ``args.log_interval`` batches. With ``test=True`` the collected
    results are also scored with ``vcoco_evaluation`` and, if
    ``args.visualize`` is set, rendered into ``result_folder``.

    Args:
        args: options namespace (``cuda``, ``visualize``, ``tmp_root``,
            ``vis_top_k``, ``log_interval``, ``debug`` are read here).
        val_loader: iterable of 16-field V-COCO batches.
        model: HOI network; put into eval mode.
        mse_loss, multi_label_loss: loss modules forwarded to ``loss_fn``.
        vcocoeval: evaluator forwarded to ``vcoco_evaluation`` when ``test``.
        logger: optional; receives ``test_epoch_loss`` and ``test_epoch_map``.
        test: whether this is the final test run (see above).

    Returns:
        ``1.0 - mAP`` over all evaluated batches (lower is better).
    """
    if args.visualize:
        result_folder = os.path.join(args.tmp_root, 'results/VCOCO/detections/', 'top'+str(args.vis_top_k))
        if not os.path.exists(result_folder):
            os.makedirs(result_folder)

    batch_time = logutil.AverageMeter()
    losses = logutil.AverageMeter()

    y_true = np.empty((0, action_class_num))
    y_score = np.empty((0, action_class_num))
    all_results = list()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # Evaluation never calls backward(), so skip recording autograd graphs;
    # this saves memory and time and should not change any computed value.
    with torch.no_grad():
        for i, (edge_features, node_features, part_human_id, adj_mat, node_labels, node_roles,
                obj_boxes, part_boxes, human_boxes, img_id, img_name, human_num, part_num,
                obj_num, obj_classes, part_classes) in enumerate(val_loader):
            edge_features = utils.to_variable(edge_features, args.cuda, non_blocking=True)
            node_features = utils.to_variable(node_features, args.cuda, non_blocking=True)
            adj_mat = utils.to_variable(adj_mat, args.cuda, non_blocking=True)
            node_labels = utils.to_variable(node_labels, args.cuda, non_blocking=True)
            node_roles = utils.to_variable(node_roles, args.cuda, non_blocking=True)

            pred_adj_mat, pred_node_labels, pred_node_roles = model(
                edge_features, node_features, part_human_id, adj_mat, node_labels, node_roles,
                human_num, part_num, obj_num, part_classes, args)
            pred_node_label_lifted, det_indices, loss = loss_fn(
                pred_adj_mat, adj_mat, pred_node_labels, node_labels, pred_node_roles, node_roles,
                human_num, part_num, obj_num, part_human_id, mse_loss, multi_label_loss)
            append_results(pred_adj_mat, adj_mat, pred_node_labels, node_labels, pred_node_roles,
                           node_roles, part_human_id, img_id, obj_boxes, part_boxes, human_boxes,
                           human_num, part_num, obj_num, obj_classes, part_classes, all_results)

            # Log (only batches with detections contribute to loss and mAP)
            if len(det_indices) > 0:
                losses.update(loss.item(), len(det_indices))
                y_true, y_score = evaluation(det_indices, pred_node_label_lifted, node_labels,
                                             y_true, y_score, test=test)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_interval == 0:
                mean_avg_prec = compute_mean_avg_prec(y_true, y_score)
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Mean Avg Precision {mean_avg_prec:.4f} ({mean_avg_prec:.4f})\t'
                      'Detected HOIs {y_shape}'
                      .format(i, len(val_loader), batch_time=batch_time,
                              loss=losses, mean_avg_prec=mean_avg_prec, y_shape=y_true.shape))

            if args.debug and i == 9:
                break

    mean_avg_prec = compute_mean_avg_prec(y_true, y_score)
    # Only the test run is scored with vcoco_evaluation; validation runs
    # report the locally computed mAP only.
    if test:
        vcoco_evaluation(args, vcocoeval, 'test', all_results)
        if args.visualize:
            utils.visualize_vcoco_result(args, result_folder, all_results)

    print(' * Average Mean Precision {mean_avg_prec:.4f}; Average Loss {loss.avg:.4f}'
          .format(mean_avg_prec=mean_avg_prec, loss=losses))

    if logger is not None:
        logger.log_value('test_epoch_loss', losses.avg)
        # Logged under a test key so it does not mix with train() 's
        # 'train_epoch_map' series.
        logger.log_value('test_epoch_map', mean_avg_prec)

    return 1.0 - mean_avg_prec
Example #3
0
# Evaluate `model` on `test_loader` without recording autograd graphs and
# print, per batch, the mean per-class accuracy of two class groups:
# "bat" (classes 0, 1, 2, 5, 6, 7) and "notbat" (starting with classes 3, 4).
# NOTE(review): this snippet is truncated in the source — the "notbat" print
# is cut off mid-expression; `total`, `correct` and `accuracy` are
# initialised but not used in the visible part.
with torch.no_grad():
    total = 0
    correct = 0
    accuracy = 0

    model.eval()
    for i, ds in enumerate(test_loader):

        # NOTE(review): re-evaluated every batch; the else-branch is a no-op
        # and the DataParallel branch only re-moves the model to the GPU.
        if type(model) is DataParallel:
            model = model.cuda()
        else:
            model = model

        # Each batch `ds` unpacks into exactly two items: labels and inputs.
        labels, echoes = list(
            map(lambda x: to_variable(x, is_cuda=is_cuda), ds))

        echoes = echoes.float()
        outputs = model(echoes)
        # Predicted class = argmax over dimension 1 of the model output.
        _, preds = torch.max(outputs.data, 1)

        # Presumably returns per-class accuracies indexed by class id;
        # `labels_in` orders the classes (the "bat" classes come first).
        # `epoch=38` appears to be a fixed tag passed to the plotting helper.
        accuracies_per_class = plot_confusion_matrix(
            labels.cpu(),
            preds.cpu(),
            labels_in=[0, 1, 2, 5, 6, 7, 3, 4, 8, 9, 10, 11],
            epoch=38)
        print('bat', (accuracies_per_class[0] + accuracies_per_class[1] +
                      accuracies_per_class[2] + accuracies_per_class[5] +
                      accuracies_per_class[6] + accuracies_per_class[7]) / 6)
        print('notbat',
              (accuracies_per_class[3] + accuracies_per_class[4] +
Example #4
0
    def decode(self, keys, values):
        """
        Decode 100 token sequences from the encoder outputs using attention.

        Runs 100 independent decoding passes (the count is hard-coded). Each
        pass feeds token id 0 as the start token, then for up to
        ``self.max_decoding_length`` steps it:

        * runs the stacked LSTM cells on the previous token embedding
          concatenated with the current attention context,
        * attends over ``keys``/``values`` with the new hidden state,
        * projects the hidden state concatenated with the new context to
          token scores, and takes the argmax.

        When ``self.is_stochastic > 0``, noise from ``self.sample_gumbel``
        (presumably Gumbel noise, i.e. Gumbel-max sampling) is added before
        the argmax. Otherwise the passes repeat the same computation, unless
        the submodules themselves are stochastic. A pass ends early once
        token id 0 is produced, and that token is still recorded.

        Only batch element 0 is recorded per step (``bs`` is fixed to 1).

        :param keys: attention keys; permuted with ``(1, 2, 0)`` for
            ``torch.bmm``, so presumably (seq_len, bs, key_dim) — TODO confirm.
        :param values: attention values; permuted with ``(1, 0, 2)``, so
            presumably (seq_len, bs, value_dim) — TODO confirm.
        :return: ``(output, raw_preds)``. ``output`` is a list with one list of
            token ids per pass. ``raw_preds`` is a list with one tensor per
            pass: the per-step ``self.softmax`` outputs (taken before any
            noise is added), concatenated along dim 0.
        """
        bs = 1  # batch_size for decoding
        output = []
        raw_preds = []

        # Each iteration is one independent decoding pass.
        for _ in range(100):
            # Per-LSTM-layer (h, c) state; filled on the first step, then updated.
            hidden_states = []
            raw_pred = None
            raw_out = []
            # Initial context: attend using a query projected from
            # lstm_cells[2]'s initial hidden state h0.
            query = self.linear(
                self.lstm_cells[2].h0)  # attention query (width must match keys' feature dim)
            attn = torch.bmm(query.unsqueeze(1),
                             keys.permute(1, 2, 0))  # (bs, 1, keys seq_len) attention scores
            attn = F.softmax(attn, dim=2)
            context = torch.bmm(attn, values.permute(1, 0, 2)).squeeze(1)

            h = self.embed(to_variable(torch.zeros(bs).long(
            )))  # Start token (id 0) used to begin generating the sentence
            for i in range(self.max_decoding_length):
                h = torch.cat((h, context), dim=1)
                # Run the stacked LSTM cells; each layer's output feeds the next.
                for j, lstm in enumerate(self.lstm_cells):
                    if i == 0:
                        h_x_0, c_x_0 = lstm(h, lstm.h0, lstm.c0)  # (h, c) for this layer
                        hidden_states.append((h_x_0, c_x_0))
                    else:
                        h_x_0, c_x_0 = hidden_states[j]
                        hidden_states[j] = lstm(h, h_x_0, c_x_0)
                    h = hidden_states[j][0]

                query = self.linear(h)  # attention query from the top LSTM output
                attn = torch.bmm(query.unsqueeze(1),
                                 keys.permute(1, 2, 0))  # (bs, 1, keys seq_len) attention scores
                # NOTE(review): masking of padded key positions is disabled.
                # attn.data.masked_fill_((1 - mask).unsqueeze(1), -float('inf'))
                attn = F.softmax(attn, dim=2)
                context = torch.bmm(attn,
                                    values.permute(1, 0,
                                                   2)).squeeze(1)  # context vector, (bs, value_dim)
                h = torch.cat((h, context), dim=1)

                # At this point h is the top LSTM output concatenated with the
                # attention context; project it to token scores.
                h = self.projection_layer1(h)
                h = self.non_linear(h)
                h = self.projection_layer2(h)
                lsm = self.softmax(h)
                # Noise is added after `lsm` is taken, so raw_pred stays noise-free.
                if self.is_stochastic > 0:
                    gumbel = torch.autograd.Variable(
                        self.sample_gumbel(shape=h.size(), out=h.data.new()))
                    h += gumbel
                # TODO: Do beam search later

                # Token id chosen for each batch element.
                h = torch.max(h, dim=1)[1]
                raw_out.append(h.data.cpu().numpy()[0])
                if raw_pred is None:
                    raw_pred = lsm
                else:
                    raw_pred = torch.cat((raw_pred, lsm), dim=0)

                # Token id 0 doubles as end-of-sequence (it is also the start token).
                if h.data.cpu().numpy() == 0:
                    break

                # Primer for next character generation
                h = self.embed(h)
            output.append(raw_out)
            raw_preds.append(raw_pred)
        return output, raw_preds
def train(args, train_loader, model, mse_loss, multi_label_loss, optimizer, epoch, vcocoeval, logger):
    """Train the V-COCO HOI model for one epoch.

    For each batch: move the tensors to the device, run ``model``, score the
    output with ``loss_fn``, collect per-image results with
    ``append_results``, and feed detections to ``evaluation`` for the running
    mAP. The loss is back-propagated and ``optimizer`` stepped only when
    ``loss_fn`` returned something other than a plain ``int``. Progress is
    printed every ``args.log_interval`` batches; with ``args.debug`` the
    epoch stops after batch index 30.

    Args:
        args: options namespace (``cuda``, ``log_interval``, ``debug``).
        train_loader: iterable of 16-field V-COCO batches.
        model: HOI network; put into train mode.
        mse_loss, multi_label_loss: loss modules forwarded to ``loss_fn``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the printed messages.
        vcocoeval: unused here (V-COCO evaluation on the training set is off).
        logger: optional; receives ``train_epoch_loss`` and ``train_epoch_map``.
    """
    batch_time, data_time, losses = (logutil.AverageMeter(),
                                     logutil.AverageMeter(),
                                     logutil.AverageMeter())

    y_true = np.empty((0, action_class_num))
    y_score = np.empty((0, action_class_num))
    all_results = []

    model.train()

    def to_device(tensor):
        return utils.to_variable(tensor, args.cuda, non_blocking=True)

    progress_fmt = ('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Mean Avg Precision {mean_avg_prec:.4f} ({mean_avg_prec:.4f})\t'
                    'Detected HOIs {y_shape}')

    tick = time.time()
    for batch_idx, batch in enumerate(train_loader):
        (edge_features, node_features, part_human_id, adj_mat, node_labels,
         node_roles, obj_boxes, part_boxes, human_boxes, img_id, img_name,
         human_num, part_num, obj_num, obj_classes, part_classes) = batch
        data_time.update(time.time() - tick)
        optimizer.zero_grad()

        edge_features = to_device(edge_features)
        node_features = to_device(node_features)
        adj_mat = to_device(adj_mat)
        node_labels = to_device(node_labels)
        node_roles = to_device(node_roles)

        pred_adj_mat, pred_node_labels, pred_node_roles = model(
            edge_features, node_features, part_human_id, adj_mat, node_labels,
            node_roles, human_num, part_num, obj_num, part_classes, args)
        pred_node_label_lifted, det_indices, loss = loss_fn(
            pred_adj_mat, adj_mat, pred_node_labels, node_labels,
            pred_node_roles, node_roles, human_num, part_num, obj_num,
            part_human_id, mse_loss, multi_label_loss)
        append_results(pred_adj_mat, adj_mat, pred_node_labels, node_labels,
                       pred_node_roles, node_roles, part_human_id, img_id,
                       obj_boxes, part_boxes, human_boxes, human_num, part_num,
                       obj_num, obj_classes, part_classes, all_results)

        if len(det_indices) > 0:
            y_true, y_score = evaluation(det_indices, pred_node_label_lifted,
                                         node_labels, y_true, y_score)

        # An int loss cannot be back-propagated (presumably loss_fn returns
        # one when the batch yields no loss terms — confirm in loss_fn).
        if not isinstance(loss, int):
            losses.update(loss.item(), edge_features.size()[0])
            loss.backward()
            optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if batch_idx % args.log_interval == 0:
            running_map = compute_mean_avg_prec(y_true, y_score)
            print(progress_fmt.format(epoch, batch_idx, len(train_loader),
                                      batch_time=batch_time,
                                      data_time=data_time, loss=losses,
                                      mean_avg_prec=running_map,
                                      y_shape=y_true.shape))

        if args.debug and batch_idx == 30:
            break

    epoch_map = compute_mean_avg_prec(y_true, y_score)
    # V-COCO evaluation of all_results on the training split is disabled.

    if logger is not None:
        logger.log_value('train_epoch_loss', losses.avg)
        logger.log_value('train_epoch_map', epoch_map)

    print('Epoch: [{0}] Avg Mean Precision {map:.4f}; Average Loss {loss.avg:.4f}; Avg Time x Batch {b_time.avg:.4f}'
          .format(epoch, map=epoch_map, loss=losses, b_time=batch_time))
Example #6
0
def gen_test_result(args, test_loader, model, mse_loss, multi_label_loss,
                    img_index):
    """Run the HICO-DET model over ``test_loader`` and save detections as .mat files.

    Every HOI that ``get_indices`` extracts from the predicted adjacency
    matrix and node labels is grouped by its object index and image id.
    For each object index, an object array ``all_boxes`` of shape
    ``(obj_end - obj_start + 1, len(img_index))`` is written to
    ``<args.tmp_root>/results/HICO/detections_<obj_idx:02d>.mat``. The bounds
    come from ``metadata.obj_hoi_index``.

    Cell ``[pair[2], col]`` for image ``img_index[col]`` holds the rows
    ``concatenate(pair[0], pair[1], [pair[3]])``. These are stacked with
    ``np.vstack``: a single 1-D row when there is one, ``[]`` when there are
    none.

    Args:
        args: options namespace; reads ``args.cuda`` and ``args.tmp_root`` and
            forwards ``args`` to ``model``.
        test_loader: iterable of ``(edge_features, node_features, adj_mat,
            node_labels, sequence_ids, det_classes, det_boxes, human_nums,
            obj_nums)`` batches.
        model: network returning ``(pred_adj_mat, pred_node_labels)``.
        mse_loss, multi_label_loss: unused; kept for interface compatibility.
        img_index: list of image ids giving the column order of the arrays.

    Raises:
        ValueError: if a detected image id is not in ``img_index``.
    """
    # switch to evaluate mode
    model.eval()

    # filtered_hoi[obj_idx][sequence_id] -> list of ``info`` entries from get_indices.
    filtered_hoi = dict()

    for i, (edge_features, node_features, adj_mat, node_labels, sequence_ids,
            det_classes, det_boxes, human_nums,
            obj_nums) in enumerate(test_loader):
        edge_features = utils.to_variable(edge_features, args.cuda)
        node_features = utils.to_variable(node_features, args.cuda)
        adj_mat = utils.to_variable(adj_mat, args.cuda)
        node_labels = utils.to_variable(node_labels, args.cuda)

        pred_adj_mat, pred_node_labels = model(edge_features, node_features,
                                               adj_mat, node_labels,
                                               human_nums, obj_nums, args)

        for batch_i in range(pred_adj_mat.size()[0]):
            sequence_id = sequence_ids[batch_i]
            hois_test = get_indices(pred_adj_mat[batch_i],
                                    pred_node_labels[batch_i],
                                    human_nums[batch_i], obj_nums[batch_i],
                                    det_classes[batch_i], det_boxes[batch_i])
            for _, o_idx, _, info, _ in hois_test:
                filtered_hoi.setdefault(o_idx, dict()).setdefault(
                    sequence_id, list()).append(info)

        print("finished generating result from " + sequence_ids[0] + " to " +
              sequence_ids[-1])

    # First position of every image id (what list.index() returns), so each
    # column lookup is O(1) instead of a scan over img_index.
    col_of = dict()
    for col, image_id in enumerate(img_index):
        col_of.setdefault(image_id, col)

    for obj_idx, save_info in filtered_hoi.items():
        obj_start, obj_end = metadata.obj_hoi_index[obj_idx]
        # ``np.object`` was only an alias of ``object`` and is removed in NumPy 1.24.
        obj_arr = np.empty((obj_end - obj_start + 1, len(img_index)),
                           dtype=object)
        for row in range(obj_arr.shape[0]):
            for col in range(obj_arr.shape[1]):
                obj_arr[row][col] = []
        for image_id, data_info in save_info.items():
            if image_id not in col_of:
                raise ValueError('{!r} is not in img_index'.format(image_id))
            col_idx = col_of[image_id]
            for pair in data_info:
                row_idx = pair[2]
                bbox_concat = np.concatenate((pair[0], pair[1], [pair[3]]))
                if len(obj_arr[row_idx][col_idx]) > 0:
                    obj_arr[row_idx][col_idx] = np.vstack(
                        (obj_arr[row_idx][col_idx], bbox_concat))
                else:
                    obj_arr[row_idx][col_idx] = bbox_concat
        sio.savemat(
            os.path.join(args.tmp_root, 'results', 'HICO',
                         'detections_' + str(obj_idx).zfill(2) + '.mat'),
            {'all_boxes': obj_arr})
        print('finished saving for ' + str(obj_idx))
Example #7
0
def validate(val_loader,
             model,
             mse_loss,
             multi_label_loss,
             logger=None,
             test=False):
    """Evaluate the HICO-DET HOI model on ``val_loader``.

    Reads options from the module-level ``args`` (``visualize``,
    ``tmp_root``, ``vis_top_k``, ``cuda``, ``log_interval``). Each batch is
    run through ``model`` and ``loss_fn``. Batches with detections update the
    running loss and the data for ``compute_mean_avg_prec``, which is printed
    every ``args.log_interval`` batches (skipping batch 0) and once at the end.

    Args:
        val_loader: iterable of ``(edge_features, node_features, adj_mat,
            node_labels, sequence_ids, det_classes, det_boxes, human_num,
            obj_num)`` batches.
        model: HOI network; put into eval mode.
        mse_loss, multi_label_loss: loss modules forwarded to ``loss_fn``.
        logger: optional; receives ``test_epoch_loss`` and ``test_epoch_map``.
        test: forwarded to ``evaluation``.

    Returns:
        ``1.0 - mAP`` over all evaluated batches (lower is better).
    """
    if args.visualize:
        # NOTE(review): the folder is only created here; nothing in this
        # function writes into it.
        result_folder = os.path.join(args.tmp_root, 'results/HICO/detections/',
                                     'top' + str(args.vis_top_k))
        if not os.path.exists(result_folder):
            os.makedirs(result_folder)

    batch_time = logutil.AverageMeter()
    losses = logutil.AverageMeter()

    y_true = np.empty((0, action_class_num))
    y_score = np.empty((0, action_class_num))

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (edge_features, node_features, adj_mat, node_labels, sequence_ids,
            det_classes, det_boxes, human_num,
            obj_num) in enumerate(val_loader):

        edge_features = utils.to_variable(edge_features, args.cuda)
        node_features = utils.to_variable(node_features, args.cuda)
        adj_mat = utils.to_variable(adj_mat, args.cuda)
        node_labels = utils.to_variable(node_labels, args.cuda)

        pred_adj_mat, pred_node_labels = model(edge_features, node_features,
                                               adj_mat, node_labels, human_num,
                                               obj_num, args)
        det_indices, loss = loss_fn(pred_adj_mat, adj_mat, pred_node_labels,
                                    node_labels, mse_loss, multi_label_loss,
                                    human_num, obj_num)

        # Log (only batches with detections contribute to loss and mAP)
        if len(det_indices) > 0:
            losses.update(loss.data[0], len(det_indices))
            y_true, y_score = evaluation(det_indices,
                                         pred_node_labels,
                                         node_labels,
                                         y_true,
                                         y_score,
                                         test=test)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_interval == 0 and i > 0:
            mean_avg_prec = compute_mean_avg_prec(y_true, y_score)
            print(
                'Test: [{0}/{1}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Mean Avg Precision {mean_avg_prec:.4f} ({mean_avg_prec:.4f})\t'
                'Detected HOIs {y_shape}'.format(i,
                                                 len(val_loader),
                                                 batch_time=batch_time,
                                                 loss=losses,
                                                 mean_avg_prec=mean_avg_prec,
                                                 y_shape=y_true.shape))

    mean_avg_prec = compute_mean_avg_prec(y_true, y_score)

    print(
        ' * Average Mean Precision {mean_avg_prec:.4f}; Average Loss {loss.avg:.4f}'
        .format(mean_avg_prec=mean_avg_prec, loss=losses))

    if logger is not None:
        logger.log_value('test_epoch_loss', losses.avg)
        # Logged under a test key so it does not mix with train() 's
        # 'train_epoch_map' series.
        logger.log_value('test_epoch_map', mean_avg_prec)

    return 1.0 - mean_avg_prec
Example #8
0
def train(train_loader, model, mse_loss, multi_label_loss, optimizer, epoch,
          logger):
    """Train the HICO-DET HOI model for one epoch.

    Reads ``args.cuda`` and ``args.log_interval`` from the module-level
    ``args``. Each batch is moved to the device and run through ``model`` and
    ``loss_fn``. The loss is back-propagated and ``optimizer`` stepped. Batches
    with detections update the running mAP, which is printed every
    ``args.log_interval`` batches and once at the end of the epoch.

    Args:
        train_loader: iterable of ``(edge_features, node_features, adj_mat,
            node_labels, sequence_ids, det_classes, det_boxes, human_num,
            obj_num)`` batches.
        model: HOI network; put into train mode.
        mse_loss, multi_label_loss: loss modules forwarded to ``loss_fn``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the printed messages.
        logger: optional; receives ``train_epoch_loss`` and ``train_epoch_map``.
    """
    batch_time, data_time, losses = (logutil.AverageMeter(),
                                     logutil.AverageMeter(),
                                     logutil.AverageMeter())

    y_true = np.empty((0, action_class_num))
    y_score = np.empty((0, action_class_num))

    model.train()

    progress_fmt = (
        'Epoch: [{0}][{1}/{2}]\t'
        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
        'Mean Avg Precision {mean_avg_prec:.4f} ({mean_avg_prec:.4f})\t'
        'Detected HOIs {y_shape}')

    tick = time.time()

    for step, batch in enumerate(train_loader):
        (edge_features, node_features, adj_mat, node_labels, sequence_ids,
         det_classes, det_boxes, human_num, obj_num) = batch

        data_time.update(time.time() - tick)
        optimizer.zero_grad()

        edge_features, node_features, adj_mat, node_labels = [
            utils.to_variable(tensor, args.cuda)
            for tensor in (edge_features, node_features, adj_mat, node_labels)
        ]

        pred_adj_mat, pred_node_labels = model(edge_features, node_features,
                                               adj_mat, node_labels, human_num,
                                               obj_num, args)
        det_indices, loss = loss_fn(pred_adj_mat, adj_mat, pred_node_labels,
                                    node_labels, mse_loss, multi_label_loss,
                                    human_num, obj_num)

        if len(det_indices) > 0:
            y_true, y_score = evaluation(det_indices, pred_node_labels,
                                         node_labels, y_true, y_score)

        losses.update(loss.data[0], edge_features.size()[0])
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.log_interval != 0:
            continue
        running_map = compute_mean_avg_prec(y_true, y_score)
        print(progress_fmt.format(epoch,
                                  step,
                                  len(train_loader),
                                  batch_time=batch_time,
                                  data_time=data_time,
                                  loss=losses,
                                  mean_avg_prec=running_map,
                                  y_shape=y_true.shape))

    epoch_map = compute_mean_avg_prec(y_true, y_score)

    if logger is not None:
        logger.log_value('train_epoch_loss', losses.avg)
        logger.log_value('train_epoch_map', epoch_map)

    print(
        'Epoch: [{0}] Avg Mean Precision {map:.4f}; Average Loss {loss.avg:.4f}; Avg Time x Batch {b_time.avg:.4f}'
        .format(epoch, map=epoch_map, loss=losses, b_time=batch_time))
Example #9
0
def main():
    """Train a CycleGAN on ``opt.dataset`` and checkpoint it after every epoch.

    ``G`` maps domain A -> B and ``F`` maps B -> A. ``Dx`` scores domain-A
    images and ``Dy`` scores domain-B images. Each step updates Dx, then Dy,
    then G and F jointly on the adversarial, cycle-consistency and identity
    losses. All settings come from the module-level ``opt``; an optional
    ``opt.checkpoint`` resumes models, optimizers and the epoch counter.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    start_epoch = 0
    train_image_dataset = image_preprocessing(opt.dataset, 'train')
    data_loader = DataLoader(train_image_dataset, batch_size=opt.batch_size,
                             shuffle=True, num_workers=opt.num_workers)
    criterion = least_squares
    # Mean absolute (L1) error, used for the cycle and identity losses.
    euclidean_l1 = nn.L1Loss()

    G = Generator(ResidualBlock, layer_count=9)
    F = Generator(ResidualBlock, layer_count=9)
    Dx = Discriminator()
    Dy = Discriminator()

    adam_kwargs = dict(lr=opt.lr, betas=(opt.beta1, opt.beta2))
    G_optimizer = optim.Adam(G.parameters(), **adam_kwargs)
    F_optimizer = optim.Adam(F.parameters(), **adam_kwargs)
    Dx_optimizer = optim.Adam(Dx.parameters(), **adam_kwargs)
    Dy_optimizer = optim.Adam(Dy.parameters(), **adam_kwargs)

    if torch.cuda.is_available():
        G = nn.DataParallel(G).cuda()
        F = nn.DataParallel(F).cuda()
        Dx = nn.DataParallel(Dx).cuda()
        Dy = nn.DataParallel(Dy).cuda()

    if opt.checkpoint is not None:
        (G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer,
         start_epoch) = load_ckp(opt.checkpoint, G, F, Dx, Dy, G_optimizer,
                                 F_optimizer, Dx_optimizer, Dy_optimizer)

    print('[Start] : Cycle GAN Training')

    logger = Logger(opt.epochs, len(data_loader), image_step=10)

    for epoch in range(start_epoch + 1, start_epoch + opt.epochs + 1):
        print("Epoch[{epoch}] : Start".format(epoch=epoch))

        for step, data in enumerate(data_loader):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])

            fake_B = G(real_A)
            fake_A = F(real_B)

            # ---- Discriminators ----
            # The fakes are detached so the discriminator losses do not
            # back-propagate into G / F (those gradients would only be
            # cleared by the generators' zero_grad() below); as a result no
            # graph has to be retained across backward() calls.
            Dx_optimizer.zero_grad()
            Dx_real = Dx(real_A)
            Dx_fake = Dx(fake_A.detach())
            Dx_loss = patch_loss(criterion, Dx_real, True) + patch_loss(criterion, Dx_fake, 0)
            Dx_loss.backward()
            Dx_optimizer.step()

            Dy_optimizer.zero_grad()
            Dy_real = Dy(real_B)
            Dy_fake = Dy(fake_B.detach())
            Dy_loss = patch_loss(criterion, Dy_real, True) + patch_loss(criterion, Dy_fake, 0)
            Dy_loss.backward()
            Dy_optimizer.step()

            # ---- Generators ----
            G_optimizer.zero_grad()
            F_optimizer.zero_grad()

            # Adversarial losses against the just-updated discriminators.
            G_loss = patch_loss(criterion, Dy(fake_B), True)
            F_loss = patch_loss(criterion, Dx(fake_A), True)

            # Identity loss: a generator fed an image that is already in its
            # output domain should return it unchanged (G(B) ~ B, F(A) ~ A).
            loss_identity = euclidean_l1(G(real_B), real_B) + euclidean_l1(F(real_A), real_A)

            # Cycle consistency: F(G(A)) ~ A and G(F(B)) ~ B.
            loss_cycle = euclidean_l1(F(fake_B), real_A) + euclidean_l1(G(fake_A), real_B)

            # Identity term weighted at half the cycle weight.
            loss = G_loss + F_loss + opt.lamda * loss_cycle + opt.lamda * loss_identity * (0.5)

            loss.backward()
            G_optimizer.step()
            F_optimizer.step()

            if (step + 1) % opt.save_step == 0:
                print("Epoch[{epoch}]| Step [{now}/{total}]| Dx Loss: {Dx_loss}, Dy_Loss: {Dy_loss}, G_Loss: {G_loss}, F_Loss: {F_loss}".format(
                    epoch=epoch, now=step + 1, total=len(data_loader), Dx_loss=Dx_loss.item(), Dy_loss=Dy_loss.item(),
                    G_loss=G_loss.item(), F_loss=F_loss.item()))
                # 2x2 grid: reals on the top row, fakes on the bottom row.
                batch_image = torch.cat((torch.cat((real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)

                torchvision.utils.save_image(denorm(batch_image[0]), opt.training_result + 'result_{result_name}_ep{epoch}_{step}.jpg'.format(result_name=opt.result_name, epoch=epoch, step=(step + 1) * opt.batch_size))

            # http://localhost:8097
            logger.log(
                losses={
                    'loss_G': G_loss,
                    'loss_F': F_loss,
                    'loss_identity': loss_identity,
                    'loss_cycle': loss_cycle,
                    'total_G_loss': loss,
                    'loss_Dx': Dx_loss,
                    'loss_Dy': Dy_loss,
                    'total_D_loss': (Dx_loss + Dy_loss),
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B,
                },
            )

        torch.save({
            'epoch': epoch,
            'G_model': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'F_model': F.state_dict(),
            'F_optimizer': F_optimizer.state_dict(),
            'Dx_model': Dx.state_dict(),
            'Dx_optimizer': Dx_optimizer.state_dict(),
            'Dy_model': Dy.state_dict(),
            'Dy_optimizer': Dy_optimizer.state_dict(),
        }, opt.save_model + 'model_{result_name}_CycleGAN_ep{epoch}.ckp'.format(result_name=opt.result_name, epoch=epoch))
def train(args, discriminator, generator, criterion, optim_dis, optim_gen,
          data_loader):
    """Train ``discriminator`` and ``generator`` for one epoch of a standard GAN.

    For every full-size batch, the discriminator is first updated on real
    images (target 1) and on detached generated images (target 0). The
    generator is then updated so that the discriminator scores freshly
    generated images as real (target 1). Latent vectors are drawn with
    ``torch.rand``, i.e. uniformly on [0, 1).

    Args:
        args: options namespace; reads ``args.batch`` and ``args.z_dim``.
        discriminator, generator: the two networks; put into train mode.
        criterion: loss called as ``criterion(prediction, target)``.
        optim_dis, optim_gen: optimizers for the discriminator / generator.
        data_loader: iterable of ``(images, _)`` batches.

    Returns:
        ``(loss_dis, loss_gen)``: the discriminator and generator losses
        averaged over the batches actually trained on. A trailing batch
        smaller than ``args.batch`` ends the epoch and is not counted. Both
        are ``0.0`` when no full batch was seen.
    """
    discriminator.train()
    generator.train()

    batch = args.batch

    # Targets have a fixed batch dimension, which is why short batches are skipped below.
    y_real = utils.to_cuda(Variable(torch.ones(batch, 1)))
    y_fake = utils.to_cuda(Variable(torch.zeros(batch, 1)))

    loss_dis, loss_gen = 0.0, 0.0
    num_trained = 0

    datas = tqdm(data_loader)
    for idx_batch, (real_images, _) in enumerate(datas):
        datas.set_description('Processing DataLoader %d' % idx_batch)

        # Ignore the final batch if it is smaller than the batch size.
        if real_images.size()[0] != batch:
            break

        real_images = utils.to_variable(real_images)
        z = utils.to_variable(torch.rand((batch, args.z_dim)))

        # Update the discriminator
        optim_dis.zero_grad()

        # For the discriminator, outputs on real images should be close to 1 (real).
        # E[log(D(x))]
        D_real = discriminator(real_images)
        loss_real = criterion(D_real, y_real)

        # For the discriminator, outputs on the generator's fake images should
        # be close to 0 (fake).
        # E[log(1 - D(G(z)))]
        # detach() stops gradients from flowing back into G through fake_images.
        fake_images = generator(z)
        D_fake = discriminator(fake_images.detach())
        loss_fake = criterion(D_fake, y_fake)

        # Minimize the sum of the two losses
        loss_dis_batch = loss_real + loss_fake
        loss_dis_batch.backward()
        optim_dis.step()  # This does not update G's parameters!
        loss_dis += float(loss_dis_batch.data)

        # Update the generator
        z = utils.to_variable(torch.rand((batch, args.z_dim)))
        optim_gen.zero_grad()

        # For the generator, the discriminator's output on generated images
        # should be close to 1 (real).
        # E[log(D(G(z)))]
        fake_images = generator(z)
        D_fake = discriminator(fake_images)
        loss_gen_batch = criterion(D_fake, y_real)
        loss_gen_batch.backward()
        optim_gen.step()
        loss_gen += float(loss_gen_batch.data)

        num_trained += 1

    # Average over the batches that were trained on; dividing by
    # len(data_loader) would also count a skipped short batch and would
    # raise ZeroDivisionError for an empty loader.
    if num_trained:
        loss_dis /= num_trained
        loss_gen /= num_trained

    return loss_dis, loss_gen