示例#1
0
def train():
    """Train the SST/SSJ tracker on the MOT training set.

    Uses module-level names defined elsewhere in the file (``net``,
    ``sst_net``, ``optimizer``, ``criterion``, ``config``, ``args``,
    ``sst_dim``, ``means``, ``batch_size``, ``max_iter``, ``stepvalues``,
    ``save_weights_iteration`` and, when ``args.tensorboard`` is set,
    ``writer``).

    Iterates from ``args.start_iter`` up to ``max_iter``. A fresh batch
    iterator is created on the first iteration and at every multiple of
    ``epoch_size``; the learning rate is decayed through
    ``adjust_learning_rate`` at every iteration listed in ``stepvalues``.
    A checkpoint ``sst300_0712_<iteration>.pth`` is written to
    ``args.save_folder`` every ``save_weights_iteration`` iterations, and
    ``<args.version>.pth`` is written to the same folder when the loop ends.
    """
    net.train()
    current_lr = config['learning_rate']
    print('Loading Dataset...')

    dataset = MOTTrainDataset(args.mot_root,
                              SSJAugmentation(sst_dim, means))

    # epoch_size is used as a modulus below; clamp it so a dataset smaller
    # than one batch does not raise ZeroDivisionError.
    epoch_size = max(1, len(dataset) // args.batch_size)
    print('Training SSJ on', dataset.dataset_name)
    step_index = 0

    batch_iterator = None

    # NOTE(review): the loader uses the module-level ``batch_size`` while
    # epoch_size is derived from ``args.batch_size``; confirm the two agree,
    # otherwise next() can run out of batches before the iterator is reset.
    data_loader = data.DataLoader(dataset, batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  pin_memory=False)

    for iteration in range(args.start_iter, max_iter):
        if batch_iterator is None or iteration % epoch_size == 0:
            # Start a new (reshuffled) pass and reset the running epoch loss.
            batch_iterator = iter(data_loader)
            all_epoch_loss = []

        if iteration in stepvalues:
            step_index += 1
            current_lr = adjust_learning_rate(optimizer, args.gamma, step_index)

        # load train data
        img_pre, img_next, boxes_pre, boxes_next, labels, valid_pre, valid_next = \
            next(batch_iterator)

        # NOTE(review): valid_pre/valid_next are wrapped as volatile in the
        # CUDA branch but not in the CPU branch — confirm which is intended.
        if args.cuda:
            img_pre = Variable(img_pre.cuda())
            img_next = Variable(img_next.cuda())
            boxes_pre = Variable(boxes_pre.cuda())
            boxes_next = Variable(boxes_next.cuda())
            valid_pre = Variable(valid_pre.cuda(), volatile=True)
            valid_next = Variable(valid_next.cuda(), volatile=True)
            labels = Variable(labels.cuda(), volatile=True)
        else:
            img_pre = Variable(img_pre)
            img_next = Variable(img_next)
            boxes_pre = Variable(boxes_pre)
            boxes_next = Variable(boxes_next)
            valid_pre = Variable(valid_pre)
            valid_next = Variable(valid_next)
            labels = Variable(labels, volatile=True)

        # forward
        t0 = time.time()
        out = net(img_pre, img_next, boxes_pre, boxes_next, valid_pre, valid_next)

        optimizer.zero_grad()
        loss_pre, loss_next, loss_similarity, loss, accuracy_pre, accuracy_next, accuracy, predict_indexes = criterion(out, labels, valid_pre, valid_next)

        loss.backward()
        optimizer.step()
        t1 = time.time()

        all_epoch_loss += [loss.data.cpu()]

        if iteration % 10 == 0:
            print('Timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ', ' + repr(epoch_size) + ' || epoch: %.4f ' % (iteration / float(epoch_size)) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')

        if args.tensorboard:
            # Only log the epoch mean once enough samples have accumulated.
            if len(all_epoch_loss) > 30:
                writer.add_scalar('data/epoch_loss', float(np.mean(all_epoch_loss)), iteration)
            writer.add_scalar('data/learning_rate', current_lr, iteration)

            writer.add_scalar('loss/loss', loss.data.cpu(), iteration)
            writer.add_scalar('loss/loss_pre', loss_pre.data.cpu(), iteration)
            writer.add_scalar('loss/loss_next', loss_next.data.cpu(), iteration)
            writer.add_scalar('loss/loss_similarity', loss_similarity.data.cpu(), iteration)

            writer.add_scalar('accuracy/accuracy', accuracy.data.cpu(), iteration)
            writer.add_scalar('accuracy/accuracy_pre', accuracy_pre.data.cpu(), iteration)
            writer.add_scalar('accuracy/accuracy_next', accuracy_next.data.cpu(), iteration)

            # add weights
            if iteration % 1000 == 0:
                for name, param in net.named_parameters():
                    writer.add_histogram(name, param.clone().cpu().data.numpy(), iteration)

            # add images
            if args.send_images and iteration % 1000 == 0:
                result_image = show_batch_circle_image(img_pre, img_next, boxes_pre, boxes_next, valid_pre, valid_next, predict_indexes, iteration)

                writer.add_image('WithLabel/ImageResult', vutils.make_grid(result_image, nrow=2, normalize=True, scale_each=True), iteration)

        if iteration % save_weights_iteration == 0:
            print('Saving state, iter:', iteration)
            torch.save(sst_net.state_dict(),
                       os.path.join(
                           args.save_folder,
                           'sst300_0712_' + repr(iteration) + '.pth'))

    # Use os.path.join like the periodic checkpoints above, so the final
    # weights land inside save_folder even without a trailing separator.
    torch.save(sst_net.state_dict(),
               os.path.join(args.save_folder, args.version + '.pth'))
示例#2
0
        continue

    if config['cuda']:
        img_pre = Variable(img_pre.cuda())
        img_next = Variable(img_next.cuda())
        boxes_pre = Variable(boxes_pre.cuda())
        boxes_next = Variable(boxes_next.cuda())
        valid_pre = Variable(valid_pre.cuda(), volatile=True)
        valid_next = Variable(valid_next.cuda(), volatile=True)
        labels = Variable(labels.cuda(), volatile=True)

    else:
        img_pre = Variable(img_pre)
        img_next = Variable(img_next)
        boxes_pre = Variable(boxes_pre)
        boxes_next = Variable(boxes_next)
        valid_pre = Variable(valid_pre)
        valid_next = Variable(valid_next)
        labels = Variable(labels, volatile=True)

    out = sst(img_pre, img_next, boxes_pre, boxes_next, valid_pre, valid_next)
    loss_pre, loss_next, loss_similarity, loss, accuracy_pre, accuracy_next, accuracy, predict_indexes = criterion(
        out, labels, valid_pre, valid_next)

    result_image = show_batch_circle_image(img_pre, img_next, boxes_pre,
                                           boxes_next, valid_pre, valid_next,
                                           predict_indexes, i)
    result_image = result_image[0, :].permute(1, 2, 0).cpu().numpy()
    cv2.imwrite(image_format.format(i), result_image)
    print(i)
def _load_camera_calibration(nusc, sample_data_token):
    """Look up calibration and ego-pose records for one camera frame.

    Args:
        nusc: a nuScenes devkit ``NuScenes`` instance.
        sample_data_token: token of a ``sample_data`` record.

    Returns:
        A tuple ``(calibrated_sensor_record, ego_pose_record,
        camera_intrinsic)``. ``camera_intrinsic`` is the third value
        returned by ``nusc.get_sample_data`` for the frame.
    """
    cam_record = nusc.get('sample_data', sample_data_token)
    _, _, camera_intrinsic = nusc.get_sample_data(cam_record['token'])
    cs_record = nusc.get(
        'calibrated_sensor',
        cam_record['calibrated_sensor_token'],
    )
    pose_record = nusc.get(
        'ego_pose',
        cam_record['ego_pose_token'],
    )
    return cs_record, pose_record, camera_intrinsic


def train():
    """Train the tracking network on the nuScenes-based training set.

    Uses module-level names defined elsewhere in the file (``net``,
    ``optimizer``, ``criterion``, ``config``, ``args``, ``means``,
    ``stepvalues``, ``TrainDataset``, ``SSJAugmentation``, ``collate_fn``,
    ``NuScenes``, ``show_batch_circle_image``).

    Runs a fixed 380000 iterations with a batch size of 1. The learning rate
    is halved at every iteration listed in ``stepvalues``. When
    ``args.Joint`` is set, calibration and ego-pose records for the current,
    next and first frames are looked up with the nuScenes devkit. They are
    passed to the network, and the pose loss it returns is added to the
    association loss with weight 0.1. Every 1000 iterations a visualisation
    is produced, and ``model_<iteration>.pth`` is saved to
    ``args.save_folder``.
    """
    net.train()
    current_lr = config['learning_rate']
    print('Loading Dataset...')

    dataset = TrainDataset(
        args.nuscenes_data_root,
        SSJAugmentation(
            config['image_size'],
            means,
        ),
    )
    print('length = ', len(dataset))

    # The loader is built with a batch size of 1 regardless of
    # args.batch_size. Derive the epoch length from that same value so one
    # epoch covers the whole dataset. Clamp to 1 so the modulus below can
    # never divide by zero.
    batch_size = 1
    epoch_size = max(1, len(dataset) // batch_size)

    batch_iterator = None
    data_loader = data.DataLoader(
        dataset,
        batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=False,
    )
    if args.Joint:
        nusc = NuScenes(
            version='v1.0-trainval',
            verbose=True,
            dataroot=args.nuscenes_root,
        )

    for iteration in range(0, 380000):
        if batch_iterator is None or iteration % epoch_size == 0:
            # Start a new (reshuffled) pass over the dataset.
            batch_iterator = iter(data_loader)

        if iteration in stepvalues:
            # Step decay: halve the learning rate at every milestone.
            current_lr = current_lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

        # load train data
        (img_pre, img_next, boxes_pre, boxes_next, labels, valid_pre,
         valid_next, current_poses, next_poses, pre_bbox, next_bbox,
         img_org_pre, img_org_next, current_tokens, next_tokens) = \
            next(batch_iterator)

        if args.Joint:
            # NOTE(review): current_tokens[1] is treated as the token of the
            # first frame — confirm against collate_fn / TrainDataset.
            current_cs_record, current_poserecord, _ = \
                _load_camera_calibration(nusc, current_tokens[0])
            next_cs_record, next_poserecord, _ = \
                _load_camera_calibration(nusc, next_tokens[0])
            first_cs_record, first_poserecord, first_camera_intrinsic = \
                _load_camera_calibration(nusc, current_tokens[1])

        if args.cuda:
            img_pre = img_pre.cuda()
            img_next = img_next.cuda()
            boxes_pre = boxes_pre.cuda()
            boxes_next = boxes_next.cuda()
            valid_pre = valid_pre.cuda()
            valid_next = valid_next.cuda()
            labels = labels.cuda()
            current_poses = current_poses.cuda()
            next_poses = next_poses.cuda()
            pre_bbox = pre_bbox.cuda()
            next_bbox = next_bbox.cuda()
            img_org_pre = img_org_pre.cuda()
            img_org_next = img_org_next.cuda()

        # forward
        if args.Joint:
            out, pose_loss = net(
                args.Joint,
                img_pre,
                img_next,
                boxes_pre,
                boxes_next,
                valid_pre,
                valid_next,
                current_poses,
                next_poses,
                pre_bbox,
                next_bbox,
                img_org_pre,
                img_org_next,
                current_cs_record,
                current_poserecord,
                next_cs_record,
                next_poserecord,
                first_cs_record,
                first_poserecord,
                first_camera_intrinsic,
            )
        else:
            out = net(
                img_pre,
                img_next,
                boxes_pre,
                boxes_next,
                valid_pre,
                valid_next,
            )

        optimizer.zero_grad()
        (loss_pre, loss_next, loss_similarity, loss, accuracy_pre,
         accuracy_next, accuracy, predict_indexes) = criterion(
            out,
            labels,
            valid_pre,
            valid_next,
        )
        # pose_loss is only produced by the joint forward pass; referencing
        # it unconditionally raises NameError when args.Joint is off.
        if args.Joint:
            total_loss = loss + 0.1 * pose_loss
        else:
            total_loss = loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print(
            'iter ' + repr(iteration) + ', ' + repr(epoch_size) +
            ' || epoch: %.4f ' % (iteration / float(epoch_size)) +
            ' || Loss: %.4f ||' % loss.item()
        )

        if iteration % 1000 == 0:
            # NOTE(review): the returned image is not used here. The call is
            # kept in case show_batch_circle_image has side effects (it is
            # given the iteration number) — confirm before removing.
            show_batch_circle_image(
                img_pre,
                img_next,
                boxes_pre,
                boxes_next,
                valid_pre,
                valid_next,
                predict_indexes,
                iteration,
            )

            print('Saving state, iter:', iteration)
            torch.save(
                net.state_dict(),
                os.path.join(
                    args.save_folder,
                    'model_' + repr(iteration) + '.pth',
                ),
            )