Exemplo n.º 1
0
    def _train_one_epoch(self, epoch):
        """Train ``self.model`` for one pass over ``self.train_loader``.

        For every batch the tensors are moved to the GPU, the model is run,
        the mean of the returned loss is back-propagated through
        ``self.optimizer.backward`` and the optimizer and the learning-rate
        scheduler are stepped. Every ``self.args.display_n_batches`` batches
        the averaged loss, the current learning rate and the batch timing are
        logged, after which only the loss average is reset.

        Args:
            epoch: Epoch number; it is only used in the log message.
        """
        self.model.train()
        running_loss = AverageMeter()
        batch_timer = TimeMeter()

        for batch_idx, batch in enumerate(self.train_loader, 1):
            (video, video_mask, words, word_mask, label,
             scores, scores_mask, id2pos, node_mask, adj_mat) = batch
            self.optimizer.zero_grad()

            # Note the naming: the model's 'label' inputs receive the loader's
            # ``scores`` tensors, while the loader's ``label`` is passed as 'gt'.
            model_input = {
                'frames': video.cuda(),
                'frame_mask': video_mask.cuda(),
                'words': words.cuda(),
                'word_mask': word_mask.cuda(),
                'label': scores.cuda(),
                'label_mask': scores_mask.cuda(),
                'gt': label.cuda(),
                'node_pos': id2pos.cuda(),
                'node_mask': node_mask.cuda(),
                'adj_mat': adj_mat.cuda(),
            }

            # The model returns five values; only the loss is needed here.
            _, batch_loss, _, _, _ = self.model(**model_input)
            batch_loss = torch.mean(batch_loss)
            # NOTE(review): ``backward`` is called on the optimizer rather than
            # on the loss, so ``self.optimizer`` is presumably a wrapper class.
            self.optimizer.backward(batch_loss)

            self.optimizer.step()
            self.num_updates += 1
            curr_lr = self.lr_scheduler.step_update(self.num_updates)

            running_loss.update(batch_loss.item())
            batch_timer.update()

            if batch_idx % self.args.display_n_batches != 0:
                continue

            # NOTE(review): reporting 1 / avg as "seconds/batch" assumes
            # ``TimeMeter.avg`` is a rate (batches per second) - confirm.
            logging.info('Epoch %d, Batch %d, loss = %.4f, lr = %.5f, %.3f seconds/batch' % (
                epoch, batch_idx, running_loss.avg, curr_lr, 1.0 / batch_timer.avg
            ))
            running_loss.reset()
Exemplo n.º 2
0
    def eval_save(self):
        """Evaluate the model on ``self.test_loader`` and save the results.

        For every batch the model is run without gradient tracking, the
        predicted boxes are rounded to integers and clipped (starts to ``>= 0``,
        ends to at most ``seq_len - 1``), and the IoU against the ground-truth
        boxes is computed with ``criteria.calculate_IoU_batch``.

        Side effects:
            * ``./our.txt`` is (over)written with one line per sample holding
              the IoU, the predicted start/end and the ground-truth start/end.
            * ``a1.npy`` and ``a2.npy`` are rewritten on every batch with the
              model's 4th and 5th outputs, so only the last batch survives.
            * The averaged loss, mIoU and IoU@0.1/0.3/0.5/0.7/0.9 are printed
              to stdout and the meters are reset afterwards.
        """
        data_loaders = [self.test_loader]
        meters = collections.defaultdict(AverageMeter)
        time_meter = TimeMeter()
        self.model.eval()
        # The ``with`` block guarantees the results file is flushed and closed,
        # even when evaluation raises part-way through.
        with open('./our.txt', 'w') as f, torch.no_grad():
            for loader_idx, data_loader in enumerate(data_loaders):
                for bid, (video, video_mask, words, word_mask,
                          label, scores, scores_mask, id2pos, node_mask, adj_mat) in enumerate(data_loader, 1):
                    # Kept from the original code: clears any stored gradients.
                    # No gradients are produced inside ``torch.no_grad()``.
                    self.optimizer.zero_grad()

                    # The model's 'label' inputs receive the loader's ``scores``
                    # tensors, while the loader's ``label`` is passed as 'gt'.
                    model_input = {
                        'frames': video.cuda(),
                        'frame_mask': video_mask.cuda(), 'words': words.cuda(), 'word_mask': word_mask.cuda(),
                        'label': scores.cuda(), 'label_mask': scores_mask.cuda(), 'gt': label.cuda(),
                        'node_pos': id2pos.cuda(), 'node_mask': node_mask.cuda(), 'adj_mat': adj_mat.cuda()
                    }

                    predict_boxes, loss, _, a1, a2 = self.model(**model_input)
                    loss = torch.mean(loss)
                    time_meter.update()
                    if bid % self.args.display_n_batches == 0:
                        # NOTE(review): 1 / avg as "seconds/batch" assumes
                        # ``TimeMeter.avg`` is a rate - confirm.
                        logging.info('%.3f seconds/batch' % (
                            1.0 / time_meter.avg
                        ))
                    meters['loss'].update(loss.item())

                    # Same file names on every iteration: each batch overwrites
                    # the arrays saved by the previous one.
                    a1, a2 = a1.cpu().numpy(), a2.cpu().numpy()
                    np.save('a1.npy', a1)
                    np.save('a2.npy', a2)

                    video_mask = video_mask.cpu().numpy()
                    gt_boxes = model_input['gt'].cpu().numpy()
                    predict_boxes = np.round(predict_boxes.cpu().numpy()).astype(np.int32)
                    gt_starts, gt_ends = gt_boxes[:, 0], gt_boxes[:, 1]
                    # These are views into ``predict_boxes``; the clipping below
                    # is done in place.
                    predict_starts, predict_ends = predict_boxes[:, 0], predict_boxes[:, 1]
                    predict_starts[predict_starts < 0] = 0
                    # Sum of the mask along the last axis - presumably the number
                    # of valid frames per sample; TODO confirm the mask encoding.
                    seq_len = np.sum(video_mask, -1)
                    past_end = predict_ends >= seq_len
                    predict_ends[past_end] = seq_len[past_end] - 1

                    IoUs = criteria.calculate_IoU_batch((predict_starts, predict_ends),
                                                        (gt_starts, gt_ends))
                    # ``!s`` forces ``str()`` on each value so the text matches
                    # plain string concatenation exactly.
                    f.writelines(
                        'IoU: {!s} start: {!s} ends: {!s} gt: {!s} {!s}\n'.format(
                            iou, start, end, gt_start, gt_end)
                        for iou, start, end, gt_start, gt_end in zip(
                            IoUs, predict_starts, predict_ends, gt_starts, gt_ends)
                    )

                    meters['mIoU'].update(np.mean(IoUs), IoUs.shape[0])
                    # Fraction of samples with IoU at or above 0.1, 0.3, ..., 0.9.
                    for i in range(1, 10, 2):
                        meters['IoU@0.%d' % i].update(np.mean(IoUs >= (i / 10)), IoUs.shape[0])

                # NOTE(review): the first (and only) loader is ``self.test_loader``
                # but is announced as "val"; the label is kept unchanged.
                if loader_idx == 0:
                    print('--------val')
                else:
                    print('--------test')
                print('| ', end='')
                for key, value in meters.items():
                    print('{}, {:.4f}'.format(key, value.avg), end=' | ')
                    value.reset()
                print()
    def _train_one_epoch(self, epoch, **kwargs):
        """Run one training pass over ``self.train_loader``.

        Each batch's ``'net_input'`` entry is moved to the GPU and fed to the
        model together with ``num_proposals``, ``random_p`` and ``tau``. The
        model output is scored by ``weakly_supervised_loss`` using the rewards
        from ``self.args['train']['rewards']``. Gradients are clipped to a norm
        of 10 before the optimizer and the learning-rate scheduler are stepped.
        Progress is logged every 50 batches and once more after the final batch
        if it did not land on that boundary.

        Args:
            epoch: Epoch number; it is only used in the log message.
            **kwargs: Accepted but not used by this method.
        """
        self.model.train()

        log_every, batch_no = 50, 0
        tau = 0.65
        timer = TimeMeter()
        loss_meters = collections.defaultdict(lambda: AverageMeter())

        rewards = torch.from_numpy(
            np.asarray(self.args['train']['rewards'])
        ).cuda()
        num_proposals = rewards.size(0)

        # Computed once per epoch from the update count at epoch start; the
        # value decays exponentially as ``self.num_updates`` grows.
        random_p = 0.5 * np.exp(-self.num_updates / 2000)

        def emit_progress(batch_index, lr):
            """Log the averaged losses and timing, then reset the loss meters."""
            pieces = ['Epoch {}, Batch {}, lr = {:.5f}, '.format(
                epoch, batch_index, lr)]
            for name, meter in loss_meters.items():
                pieces.append('{} = {:.4f}, '.format(name, meter.avg))
                meter.reset()
            # NOTE(review): 1 / avg as "seconds/batch" assumes ``TimeMeter.avg``
            # is a rate - confirm.
            pieces.append('{:.3f} seconds/batch'.format(1.0 / timer.avg))
            logging.info(''.join(pieces))

        for batch_no, batch in enumerate(self.train_loader, 1):
            self.optimizer.zero_grad()
            net_input = move_to_cuda(batch['net_input'])
            output = self.model(
                **net_input,
                num_proposals=num_proposals,
                random_p=random_p,
                tau=tau,
            )
            loss, loss_dict = weakly_supervised_loss(**output, rewards=rewards)

            # Backward pass with gradient-norm clipping.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)

            # Parameter and learning-rate update.
            self.optimizer.step()
            self.num_updates += 1
            current_lr = self.lr_scheduler.step_update(self.num_updates)
            timer.update()
            for name, value in loss_dict.items():
                loss_meters[name].update(value)

            if batch_no % log_every == 0:
                emit_progress(batch_no, current_lr)

        # Flush the statistics of a trailing partial window. ``batch_no`` stays
        # 0 when the loader is empty, so nothing is logged in that case.
        if batch_no % log_every != 0:
            emit_progress(batch_no, current_lr)