def adnet_train_sl(args, opts):
    """Train ADNet with supervised learning (SL) over per-video domains.

    The network is built by ``adnet``; when ``args.resume`` is set, only the
    optimizer state is restored here. Each iteration draws either a positive or
    a negative minibatch (schedule reshuffled every epoch) from the current
    video's loaders. Positive batches add a cross-entropy action loss; every
    batch adds a cross-entropy score loss. Checkpoints are written at each
    epoch boundary, every 5000 iterations and once at the end.

    Args:
        args: parsed CLI namespace. Reads ``cuda``, ``save_folder``,
            ``save_file``, ``visualize``, ``resume``, ``multidomain``,
            ``num_workers``, ``start_epoch``, ``start_iter`` and
            ``send_images_to_visualization``.
        opts: options dict. Reads ``train.momentum``, ``train.weightDecay``,
            ``minibatch_size`` and ``numEpoch``; writes ``num_videos``.

    Returns:
        tuple: ``(net, domain_specific_nets, train_videos)``. ``net`` is wrapped
        in ``nn.DataParallel`` when ``args.cuda`` is set.
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " + "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    os.makedirs(args.save_folder, exist_ok=True)

    writer = None
    if args.visualize:
        writer = SummaryWriter(log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts, trained_file=args.resume, multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    # The bare ADNet module: DataParallel exposes its attributes via ``.module``.
    model = net.module if args.cuda else net
    device = 'cuda' if args.cuda else 'cpu'

    optimizer = optim.SGD([
        {'params': model.base_network.parameters(), 'lr': 1e-4},
        {'params': model.fc4_5.parameters()},
        {'params': model.fc6.parameters()},
        {'params': model.fc7.parameters()}],  # as action dynamic is zero, it doesn't matter
        lr=1e-3, momentum=opts['train']['momentum'], weight_decay=opts['train']['weightDecay'])

    if args.resume:
        # Network weights are presumably restored by adnet(trained_file=args.resume)
        # above; only the optimizer state is restored here. Map to CPU when not
        # using CUDA so GPU-saved checkpoints still load.
        checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    net.train()

    if not args.resume:
        print('Initializing weights...')
        # Weights ~ N(0, 0.01^2); fc4/fc5 biases 0.1, fc6/fc7 biases 0.
        for layer, bias in ((model.fc4_5[0], 0.1), (model.fc4_5[3], 0.1),
                            (model.fc6, 0), (model.fc7, 0)):
            nn.init.normal_(layer.weight.data, std=0.01)
            layer.bias.data.fill_(bias)

    action_criterion = nn.CrossEntropyLoss()
    score_criterion = nn.CrossEntropyLoss()

    print('generating Supervised Learning dataset..')
    datasets_pos, datasets_neg = initialize_pos_neg_dataset(train_videos, opts, transform=ADNet_Augmentation(opts))
    number_domain = opts['num_videos']

    # One epoch = every full minibatch of positives plus every full minibatch of negatives.
    epoch_size_pos = sum(len(d) for d in datasets_pos) // opts['minibatch_size']
    epoch_size_neg = sum(len(d) for d in datasets_neg) // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations
    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    def make_loader(dataset):
        # One shuffled loader per video (domain).
        return data.DataLoader(dataset, opts['minibatch_size'], num_workers=args.num_workers,
                               shuffle=True, pin_memory=True)

    data_loaders_pos = [make_loader(d) for d in datasets_pos]
    data_loaders_neg = [make_loader(d) for d in datasets_neg]

    epoch = args.start_epoch
    if epoch != 0 and args.start_iter == 0:
        start_iter = epoch * epoch_size
    else:
        start_iter = args.start_iter

    # Per-iteration schedule: 1 = positive minibatch, 0 = negative minibatch.
    which_dataset = list(np.full(epoch_size_pos, fill_value=1))
    which_dataset.extend(np.zeros(epoch_size_neg, dtype=int))
    shuffle(which_dataset)

    which_domain = np.random.permutation(number_domain)

    def save_checkpoint(suffix):
        # Reads the current ``epoch`` at call time. The domain-specific nets are
        # stored as module objects under this key, as in the existing format.
        torch.save({
            'epoch': epoch,
            'adnet_state_dict': net.state_dict(),
            'adnet_domain_specific_state_dict': domain_specific_nets,
            'optimizer_state_dict': optimizer.state_dict(),
        }, os.path.join(args.save_folder, args.save_file) + suffix)

    batch_iterators_pos = []
    batch_iterators_neg = []
    action_loss = 0
    score_loss = 0

    # training loop
    for iteration in range(start_iter, max_iter):
        curr_domain = which_domain[iteration % len(which_domain)] if args.multidomain else 0

        # if new epoch (not including the very first iteration)
        if (iteration != start_iter) and (iteration % epoch_size == 0):
            epoch += 1
            shuffle(which_dataset)
            np.random.shuffle(which_domain)

            print('Saving state, epoch:', epoch)
            save_checkpoint('epoch' + repr(epoch) + '.pth')

            if args.visualize:
                writer.add_scalars('data/epoch_loss', {'action_loss': action_loss / epoch_size,
                                                       'score_loss': score_loss / epoch_size,
                                                       'total': (action_loss + score_loss) / epoch_size}, global_step=epoch)

            # reset epoch loss counters
            action_loss = 0
            score_loss = 0

        # At each epoch start (or when resuming before any iterator exists),
        # rebuild the iterators. Replace rather than append: appending kept the
        # stale iterators in use and leaked the new ones, each of which holds
        # DataLoader worker processes.
        if iteration % epoch_size == 0 or not batch_iterators_pos or not batch_iterators_neg:
            batch_iterators_pos = [iter(loader) for loader in data_loaders_pos]
            batch_iterators_neg = [iter(loader) for loader in data_loaders_neg]

        # load train data; restart this domain's iterator when it is exhausted
        is_positive = bool(which_dataset[iteration % len(which_dataset)])
        if is_positive:
            iterators, loaders = batch_iterators_pos, data_loaders_pos
        else:
            iterators, loaders = batch_iterators_neg, data_loaders_neg
        try:
            images, bbox, action_label, score_label, vid_idx = next(iterators[curr_domain])
        except StopIteration:
            iterators[curr_domain] = iter(loaders[curr_domain])
            images, bbox, action_label, score_label, vid_idx = next(iterators[curr_domain])

        images = images.to(device, non_blocking=True)
        action_label = action_label.to(device, non_blocking=True)
        score_label = score_label.to(device, non_blocking=True)

        t0 = time.time()

        # load ADNetDomainSpecific with video index (must precede the forward pass)
        model.load_domain_specific(domain_specific_nets[curr_domain])

        # forward
        action_out, score_out = net(images)

        # backprop
        optimizer.zero_grad()
        score_l = score_criterion(score_out, score_label.long())
        if is_positive:
            # argmax over dim 1 turns the action label rows into class indices
            action_l = action_criterion(action_out, torch.max(action_label, 1)[1])
        else:
            # negatives carry no action supervision
            action_l = score_l.new_zeros(())
        loss = action_l + score_l
        loss.backward()
        optimizer.step()

        action_loss += action_l.item()
        score_loss += score_l.item()

        # save the ADNetDomainSpecific back to their module
        domain_specific_nets[curr_domain].load_weights_from_adnet(model)

        t1 = time.time()

        if iteration % 10 == 0:
            print('Timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')
            if args.visualize and args.send_images_to_visualization:
                random_batch_index = np.random.randint(images.size(0))
                writer.add_image('image', images[random_batch_index].detach().cpu().numpy(), iteration)

        if args.visualize:
            writer.add_scalars('data/iter_loss', {'action_loss': action_l.item(),
                                                  'score_loss': score_l.item(),
                                                  'total': (action_l.item() + score_l.item())}, global_step=iteration)
            # hacky fencepost solution for 0th epoch plot
            if iteration == 0:
                writer.add_scalars('data/epoch_loss', {'action_loss': action_loss,
                                                       'score_loss': score_loss,
                                                       'total': (action_loss + score_loss)}, global_step=epoch)

        if iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            save_checkpoint(repr(iteration) + '_epoch' + repr(epoch) + '.pth')

    # final save
    save_checkpoint('.pth')

    return net, domain_specific_nets, train_videos
Esempio n. 2
0
def adnet_train_sl_mot(args, opts, mot, num_obj_to_track=2):
    """Supervised-learning training of the multi-object ADNet (``adnet_mot``).

    Like the single-object SL trainer, but uses ``BCEWithLogitsLoss`` for both
    heads. Action labels are flattened to
    ``num_obj_to_track * opts['num_actions']`` columns, and the score label is
    broadcast across the columns of the score output. Each per-video dataset
    is split into a train part and a held-out part. Only the train loaders are
    used by this loop.

    Args:
        args: parsed CLI namespace. Reads ``cuda``, ``save_folder``,
            ``save_file``, ``visualize``, ``resume``, ``multidomain``,
            ``start_epoch``, ``start_iter`` and
            ``send_images_to_visualization``.
        opts: options dict. Reads ``train.momentum``, ``train.weightDecay``,
            ``minibatch_size``, ``numEpoch``, ``val_percent`` and
            ``num_actions``; writes ``num_videos``.
        mot: not used by this body (kept for interface compatibility).
        num_obj_to_track: objects per sample; used for label reshaping and
            per-object accuracy logging.

    Returns:
        tuple: ``(net, domain_specific_nets, train_videos)``.
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    os.makedirs(args.save_folder, exist_ok=True)

    writer = None
    if args.visualize:
        writer = SummaryWriter(
            log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet_mot(opts=opts,
                                          trained_file=args.resume,
                                          multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    # The bare network: DataParallel exposes its attributes via ``.module``.
    model = net.module if args.cuda else net
    device = 'cuda' if args.cuda else 'cpu'

    param_groups = [
        {'params': model.base_network.parameters(), 'lr': 1e-4},
        {'params': model.fc4_5.parameters()},
        {'params': model.fc6.parameters()},
        {'params': model.fc7.parameters()},  # as action dynamic is zero, it doesn't matter
    ]
    if args.cuda:
        # NOTE(review): the CUDA path uses Adam while the CPU path uses SGD with
        # momentum, so the two paths train differently -- confirm this is intended.
        optimizer = optim.Adam(param_groups,
                               lr=1e-3,
                               weight_decay=opts['train']['weightDecay'])
    else:
        optimizer = optim.SGD(param_groups,
                              lr=1e-3,
                              momentum=opts['train']['momentum'],
                              weight_decay=opts['train']['weightDecay'])

    if args.resume:
        # Only the optimizer state is restored here. Map to CPU when not using
        # CUDA so GPU-saved checkpoints still load.
        checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    net.train()

    if not args.resume:
        print('Initializing weights...')
        # Weights ~ N(0, 0.01^2); fc4/fc5 biases 0.1, fc6/fc7 biases 0.
        for layer, bias in ((model.fc4_5[0], 0.1), (model.fc4_5[3], 0.1),
                            (model.fc6, 0), (model.fc7, 0)):
            nn.init.normal_(layer.weight.data, std=0.01)
            layer.bias.data.fill_(bias)

    action_criterion = nn.BCEWithLogitsLoss()
    score_criterion = nn.BCEWithLogitsLoss()

    print('generating Supervised Learning dataset..')
    datasets_pos, datasets_neg = initialize_pos_neg_dataset_adnet_mot(
        train_videos, opts, transform=ADNet_Augmentation(opts))
    number_domain = opts['num_videos']
    assert number_domain == len(
        datasets_pos
    ), "Num videos given in opts is incorrect! It should be {}".format(
        len(datasets_pos))

    # One epoch = every full minibatch of positives plus every full minibatch of negatives.
    epoch_size_pos = sum(len(d) for d in datasets_pos) // opts['minibatch_size']
    epoch_size_neg = sum(len(d) for d in datasets_neg) // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations
    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    def split_loaders(dataset, num_val, num_workers):
        # Randomly hold out num_val samples; return (train_loader, val_loader).
        num_train = len(dataset) - num_val
        train_set, val_set = torch.utils.data.random_split(dataset, [num_train, num_val])
        train_loader = data.DataLoader(train_set, opts['minibatch_size'], num_workers=num_workers,
                                       shuffle=True, pin_memory=True)
        val_loader = data.DataLoader(val_set, opts['minibatch_size'], num_workers=0,
                                     shuffle=True, pin_memory=False)
        return train_loader, val_loader

    data_loaders_pos_train, data_loaders_pos_val = [], []
    data_loaders_neg_train, data_loaders_neg_val = [], []
    # Positives hold out exactly one sample; opts['val_percent'] applies to negatives only.
    for dataset_pos in datasets_pos:
        train_loader, val_loader = split_loaders(dataset_pos, 1, 2)
        data_loaders_pos_train.append(train_loader)
        data_loaders_pos_val.append(val_loader)
    for dataset_neg in datasets_neg:
        num_val = int(opts['val_percent'] * len(dataset_neg))
        train_loader, val_loader = split_loaders(dataset_neg, num_val, 1)
        data_loaders_neg_train.append(train_loader)
        data_loaders_neg_val.append(val_loader)
    # NOTE(review): the *_val loaders are built but not used; per-epoch
    # validation is currently disabled.

    epoch = args.start_epoch
    if epoch != 0 and args.start_iter == 0:
        start_iter = epoch * epoch_size
    else:
        start_iter = args.start_iter

    # Per-iteration schedule: 1 = positive minibatch, 0 = negative minibatch.
    # Kept as a plain list (not a CUDA tensor). random.shuffle on a torch tensor
    # swaps through views and duplicates entries, and .cuda() broke CPU runs.
    which_dataset = list(np.full(epoch_size_pos, fill_value=1))
    which_dataset.extend(np.zeros(epoch_size_neg, dtype=int))
    shuffle(which_dataset)

    which_domain = np.random.permutation(number_domain)

    def save_checkpoint(suffix):
        # Reads the current ``epoch`` at call time. The domain-specific nets are
        # stored as module objects under this key, as in the existing format.
        torch.save(
            {
                'epoch': epoch,
                'adnet_state_dict': net.state_dict(),
                'adnet_domain_specific_state_dict': domain_specific_nets,
                'optimizer_state_dict': optimizer.state_dict(),
            },
            os.path.join(args.save_folder, args.save_file) + suffix)

    batch_iterators_pos_train = []
    batch_iterators_neg_train = []
    action_loss_tr = 0
    score_loss_tr = 0

    # training loop
    time_arr = np.zeros(10)
    for iteration in tqdm(range(start_iter, max_iter)):
        t0 = time.time()
        curr_domain = which_domain[iteration % len(which_domain)] if args.multidomain else 0

        # if new epoch (not including the very first iteration)
        if (iteration != start_iter) and (iteration % epoch_size == 0):
            epoch += 1
            shuffle(which_dataset)
            np.random.shuffle(which_domain)

            print('Saving state, epoch: {}'.format(epoch))
            save_checkpoint('epoch' + repr(epoch) + '.pth')

            if args.visualize:
                # Action loss only accrues on positive iterations; guard the empty case.
                action_epoch = action_loss_tr / max(epoch_size_pos, 1)
                score_epoch = score_loss_tr / epoch_size
                writer.add_scalars('data/epoch_loss', {
                    'action_loss_tr': action_epoch,
                    'score_loss_tr': score_epoch,
                    'total_tr': action_epoch + score_epoch
                }, global_step=epoch)

            # reset epoch loss counters
            action_loss_tr = 0
            score_loss_tr = 0

        # Create the iterators once. After that, each one is recreated when it is exhausted.
        if not batch_iterators_pos_train or not batch_iterators_neg_train:
            batch_iterators_pos_train = [iter(loader) for loader in data_loaders_pos_train]
            batch_iterators_neg_train = [iter(loader) for loader in data_loaders_neg_train]

        # load train data
        is_positive = bool(which_dataset[iteration % len(which_dataset)])
        if is_positive:
            iterators, loaders = batch_iterators_pos_train, data_loaders_pos_train
        else:
            iterators, loaders = batch_iterators_neg_train, data_loaders_neg_train
        try:
            images_list, bbox_list, action_labels, score_label, vid_idx = next(
                iterators[curr_domain])
        except StopIteration:
            iterators[curr_domain] = iter(loaders[curr_domain])
            images_list, bbox_list, action_labels, score_label, vid_idx = next(
                iterators[curr_domain])

        # TODO: make sure different obj are paired differenlty, so not always pos with pos
        # Device-agnostic transfer; BCEWithLogitsLoss needs float targets.
        images_list = images_list.to(device, non_blocking=True)
        action_labels = action_labels.float().to(device, non_blocking=True)
        score_label = score_label.float().to(device, non_blocking=True)

        # TRAIN
        net.train()
        # Load this video's domain-specific layers BEFORE the forward pass.
        # Otherwise the gradients come from the previous video's weights.
        model.load_domain_specific(domain_specific_nets[curr_domain])
        action_out, score_out = net(images_list)

        # backprop
        optimizer.zero_grad()
        # Broadcast the per-sample score label across every column of score_out.
        score_l = score_criterion(score_out,
                                  score_label.reshape(-1, 1).expand_as(score_out))
        accuracy_arr = [-1] * num_obj_to_track  # -1 = not measured (negative batch)
        if is_positive:
            num_actions = opts['num_actions']
            action_l = action_criterion(
                action_out,
                action_labels.reshape(-1, num_obj_to_track * num_actions))
            loss = action_l + score_l
            batch_size = len(action_labels)
            accuracy_arr = [
                int(action_labels[:, i, :].argmax(dim=1).eq(
                    action_out[:, i * num_actions:(i + 1) * num_actions].argmax(dim=1)).sum())
                / batch_size
                for i in range(num_obj_to_track)
            ]
        else:
            action_l = None  # negatives carry no action supervision
            loss = score_l

        loss.backward()
        optimizer.step()

        if action_l is not None:
            action_loss_tr += action_l.item()
        score_loss_tr += score_l.item()

        # save the ADNetDomainSpecific back to their module
        domain_specific_nets[curr_domain].load_weights_from_adnet(model)

        if args.visualize:
            if action_l is not None:
                writer.add_scalars(
                    'data/iter_loss', {
                        'action_loss_tr': action_l.item(),
                        'score_loss_tr': score_l.item(),
                        'total_tr': (action_l.item() + score_l.item())
                    },
                    global_step=iteration)
            else:
                writer.add_scalars('data/iter_loss', {
                    'score_loss_tr': score_l.item(),
                    'total_tr': score_l.item()
                }, global_step=iteration)
            for i, accuracy in enumerate(accuracy_arr):
                if accuracy >= 0:
                    writer.add_scalars('data/iter_acc_{}'.format(i),
                                       {'accuracy_tr': accuracy},
                                       global_step=iteration)

        t1 = time.time()
        time_arr[iteration % 10] = t1 - t0

        if iteration % 10 == 0 and args.visualize and args.send_images_to_visualization:
            # NOTE(review): add_image expects one CHW image; confirm that
            # images_list[idx] is a single image for the MOT loader.
            random_batch_index = np.random.randint(images_list.size(0))
            writer.add_image('image',
                             images_list[random_batch_index].detach().cpu().numpy(),
                             iteration)

        if args.visualize:
            writer.add_scalars('data/time', {'time_10_it': time_arr.sum()},
                               global_step=iteration)

        if iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            save_checkpoint(repr(iteration) + '_epoch' + repr(epoch) + '.pth')

    # final save
    save_checkpoint('.pth')

    return net, domain_specific_nets, train_videos
Esempio n. 3
0
    # Supervised-learning stage: run adnet_train_sl with a minibatch size of
    # 128, then point args.resume at the final checkpoint it writes.
    # NOTE(review): this is a fragment -- the `if` header that selects this
    # branch is not visible here.
    opts['minibatch_size'] = 128
    # train with supervised learning
    _, _, train_videos = adnet_train_sl(args, opts)
    # Same path as the final torch.save at the end of adnet_train_sl.
    args.resume = os.path.join(args.save_folder, args.save_file) + '.pth'

    # reinitialize the network with network from SL
    # (random_initialize_domain_specific=True presumably re-initializes the
    # per-video layers instead of loading them -- confirm in adnet()).
    net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=True,
                                      multidomain=args.multidomain)

    # Restart the counters for the reinforcement-learning stage below.
    args.start_epoch = 0
    args.start_iter = 0

else:
    # No SL run in this branch: an existing SL or RL checkpoint is required.
    assert args.resume is not None, \
        "Please put result of supervised learning or reinforcement learning with --resume (filename)"
    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    if args.start_iter == 0:  # means the weight came from the SL
        net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=True, multidomain=args.multidomain)
    else:  # resume the adnet
        net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=False, multidomain=args.multidomain)

# Wrap the network in DataParallel and move it to the GPU when CUDA is requested.
if args.cuda:
    net = nn.DataParallel(net)
    cudnn.benchmark = True

    net = net.cuda()

# Reinforcement Learning part
opts['minibatch_size'] = 32
Esempio n. 4
0
def do():
    """Parse command-line options, then run ADNet supervised and/or RL training.

    Flow:

    * ``--run_supervised``: run supervised learning (plain or MOT variant).
      With ``--test``, it instead evaluates via ``adnet_test_sl``, which exits
      the process. The network is then rebuilt from the freshly saved SL
      checkpoint with randomly initialised domain-specific layers, and the
      start epoch/iteration counters are reset.
    * otherwise: load the checkpoint given by ``--resume`` (plain ADNet or
      the MOT variant, depending on ``--adnet_mot``), keeping its
      domain-specific layers.

    In both cases the network is optionally wrapped in ``nn.DataParallel``
    (``--cuda``) and handed to reinforcement-learning training.

    ``opts`` is not a parameter; it is looked up from the enclosing scope.
    Its ``minibatch_size`` entry is overwritten, and on the resume path its
    ``num_videos`` entry is set as well.
    """
    parser = argparse.ArgumentParser(description='ADNet training')
    # Boolean flags go through str2bool: with type=str or type=bool, any
    # non-empty value (including the literal "False") would be truthy.
    parser.add_argument('--adnet_mot',
                        default=False,
                        type=str2bool,
                        help='Whether to use the MOT variant of ADNet.')

    parser.add_argument('--test',
                        default=False,
                        type=str2bool,
                        help='Whether to test or train.')
    parser.add_argument('--mot',
                        default=False,
                        type=str2bool,
                        help='Perform MOT tracking')
    parser.add_argument('--resume',
                        default='weights/ADNet_RL_FINAL.pth',
                        type=str,
                        help='Resume from checkpoint')
    parser.add_argument('--num_workers',
                        default=4,
                        type=int,
                        help='Number of workers used in dataloading')
    parser.add_argument(
        '--start_iter',
        default=0,
        type=int,
        help=
        'Begin counting iterations starting from this value (should be used with resume)'
    )
    parser.add_argument('--cuda',
                        default=True,
                        type=str2bool,
                        help='Use cuda to train model')
    parser.add_argument('--gamma',
                        default=0.1,
                        type=float,
                        help='Gamma update for SGD')
    parser.add_argument('--visualize',
                        default=True,
                        type=str2bool,
                        help='Use tensorboardx for loss visualization')
    parser.add_argument(
        '--send_images_to_visualization',
        type=str2bool,
        default=False,
        help=
        'Sample a random image from each 10th batch, send it to visdom after augmentations step'
    )
    parser.add_argument('--save_folder',
                        default='weights',
                        help='Location to save checkpoint models')

    parser.add_argument('--save_file',
                        default='ADNet_SL_MOT',
                        type=str,
                        help='save file part of file name for SL')
    parser.add_argument('--save_file_RL',
                        default='ADNet_RL_',
                        type=str,
                        help='save file part of file name for RL')
    parser.add_argument('--start_epoch',
                        default=0,
                        type=int,
                        help='Begin counting epochs starting from this value')

    parser.add_argument('--run_supervised',
                        default=False,
                        type=str2bool,
                        help='Whether to run supervised learning or not')

    parser.add_argument(
        '--multidomain',
        default=True,
        type=str2bool,
        help='Separating weight for each videos (default) or not')

    parser.add_argument(
        '--save_result_images',
        default=False,
        type=str2bool,
        help='Whether to save the results or not. Save folder: images/')
    parser.add_argument('--display_images',
                        default=True,
                        type=str2bool,
                        help='Whether to display images or not')

    args = parser.parse_args()

    # Supervised Learning part
    if args.run_supervised:
        opts['minibatch_size'] = 256
        # train (or, with --test, evaluate) with supervised learning
        if args.test:
            args.save_file += "test"
            _, _, train_videos = adnet_test_sl(args, opts, mot=args.mot)
        elif args.adnet_mot:
            _, _, train_videos = adnet_train_sl_mot(args, opts, mot=args.mot)
        else:
            _, _, train_videos = adnet_train_sl(args, opts, mot=args.mot)
        args.resume = os.path.join(args.save_folder, args.save_file) + '.pth'

        # reinitialize the network with network from SL
        net, domain_specific_nets = adnet(
            opts,
            trained_file=args.resume,
            random_initialize_domain_specific=True,
            multidomain=args.multidomain)

        args.start_epoch = 0
        args.start_iter = 0

    else:
        # parser.error (unlike assert) is not stripped under `python -O`.
        if args.resume is None:
            parser.error(
                "Please put result of supervised learning or reinforcement learning with --resume (filename)")
        train_videos = get_train_videos(opts)
        opts['num_videos'] = len(train_videos['video_names'])

        # Resume the adnet from the checkpoint, keeping its domain-specific
        # layers (random_initialize_domain_specific=False).
        build_net = adnet_mot if args.adnet_mot else adnet
        net, domain_specific_nets = build_net(
            opts,
            trained_file=args.resume,
            random_initialize_domain_specific=False,
            multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True

        net = net.cuda()

    # Reinforcement Learning part
    opts['minibatch_size'] = opts['train']['RL_steps']
    if args.adnet_mot:
        # NOTE(review): meaning of the trailing positional `2` is defined by
        # adnet_train_rl_mot, which is not visible here -- confirm.
        adnet_train_rl_mot(net, domain_specific_nets, train_videos, opts,
                           args, 2)
    else:
        adnet_train_rl(net, domain_specific_nets, train_videos, opts, args)
Esempio n. 5
0
def adnet_test_sl(args, opts, mot):
    """Evaluate a supervised-learning ADNet checkpoint on the training videos.

    For every video (domain), the matching domain-specific layers are loaded.
    The positive and negative sample loaders are then run through the network
    and the following are printed per video:

    * action accuracy and cross-entropy action loss (positive samples only),
    * binary-cross-entropy confidence ("score") loss (positive and negative).

    With ``args.display_images`` each sample is drawn with its box, and the
    target/predicted action and confidence are printed. Press ``q`` to stop
    showing images for the current loader, or ``s`` to save the current image.

    Args:
        args: parsed command-line namespace. Uses ``cuda``, ``save_folder``,
            ``visualize``, ``save_file``, ``resume``, ``multidomain`` and
            ``display_images``.
        opts: option dict. Reads ``minibatch_size`` and ``numEpoch`` and sets
            ``num_videos``.
        mot: if true, build the datasets with
            ``initialize_pos_neg_dataset_mot`` instead of
            ``initialize_pos_neg_dataset``.

    Note:
        The function terminates the process with ``sys.exit(0)`` when done;
        the trailing ``return`` is never reached.
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Batches must be moved to the device the network lives on; the network is
    # only moved to the GPU when args.cuda is set.
    device = torch.device('cuda' if args.cuda else 'cpu')

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    if args.visualize:
        # NOTE(review): the writer is never used in this function; kept for
        # parity with the training entry points (it creates the log dir).
        writer = SummaryWriter(
            log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts,
                                      trained_file=args.resume,
                                      multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True

        net = net.cuda()

    # DataParallel hides the model's own methods behind `.module`.
    model = net.module if args.cuda else net

    action_criterion = nn.CrossEntropyLoss()
    score_criterion = nn.BCELoss()

    print('generating Supervised Learning dataset..')
    init_datasets = (initialize_pos_neg_dataset_mot
                     if mot else initialize_pos_neg_dataset)
    datasets_pos, datasets_neg = init_datasets(
        train_videos, opts, transform=ADNet_Augmentation(opts))
    number_domain = opts['num_videos']
    assert number_domain == len(
        datasets_pos
    ), "Num videos given in opts is incorrect! It should be {}".format(
        len(datasets_pos))

    # calculating number of data (reported only; the loop below runs one pass)
    len_dataset_pos = sum(len(dataset_pos) for dataset_pos in datasets_pos)
    len_dataset_neg = sum(len(dataset_neg) for dataset_neg in datasets_neg)

    epoch_size_pos = len_dataset_pos // opts['minibatch_size']
    epoch_size_neg = len_dataset_neg // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations
    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    def _make_loader(dataset):
        # One shuffled loader per video, same settings for pos and neg.
        return data.DataLoader(dataset,
                               opts['minibatch_size'],
                               num_workers=2,
                               shuffle=True,
                               pin_memory=True)

    data_loaders_pos_val = [_make_loader(d) for d in datasets_pos]
    data_loaders_neg_val = [_make_loader(d) for d in datasets_neg]

    net.eval()

    for curr_domain in range(number_domain):
        accuracy = []
        action_loss_val = []
        score_loss_val = []

        # load ADNetDomainSpecific with video index
        model.load_domain_specific(domain_specific_nets[curr_domain])

        # i == 0: positive samples, i == 1: negative samples
        for i, loader in enumerate([
                data_loaders_pos_val[curr_domain],
                data_loaders_neg_val[curr_domain]
        ]):
            is_positive = i == 0
            dont_show = False
            for images, _, action_label, score_label, indices in tqdm(loader):
                images = images.to(device, non_blocking=True)
                action_label = action_label.to(device, non_blocking=True)
                score_label = score_label.float().to(device,
                                                     non_blocking=True)

                # forward (evaluation only: no autograd graph needed)
                with torch.no_grad():
                    action_out, score_out = net(images)

                if is_positive:
                    action_l = action_criterion(action_out,
                                                torch.max(action_label, 1)[1])
                    action_loss_val.append(action_l.item())
                    accuracy.append(
                        int(
                            action_label.argmax(axis=1).eq(
                                action_out.argmax(axis=1)).sum()) /
                        len(action_label))

                score_l = score_criterion(score_out,
                                          score_label.reshape(-1, 1))
                score_loss_val.append(score_l.item())

                if not args.display_images or dont_show:
                    continue

                if is_positive:
                    dataset = datasets_pos[curr_domain]
                    color = (0, 255, 0)
                    conf = 1
                else:
                    dataset = datasets_neg[curr_domain]
                    color = (0, 0, 255)
                    conf = 0
                # j indexes the sample within the current batch.
                for j, index in enumerate(indices):
                    im = cv2.imread(dataset.train_db['img_path'][index])
                    box = dataset.train_db['bboxes'][index]
                    target_action = np.array(dataset.train_db['labels'][index])
                    # cv2.rectangle requires integer pixel coordinates.
                    x, y = int(box[0]), int(box[1])
                    w, h = int(box[2]), int(box[3])
                    cv2.rectangle(im, (x, y), (x + w, y + h), color, 2)

                    print("\n\nTarget actions: {}".format(
                        target_action.argmax()))
                    print("Predicted actions: {}".format(
                        action_out.data[j].argmax()))

                    print("Target conf: {}".format(conf))
                    print("Predicted conf: {}".format(score_out.data[j]))
                    cv2.imshow("Test", im)
                    key = cv2.waitKey(0) & 0xFF

                    # if the `q` key was pressed, break from the loop
                    if key == ord("q"):
                        dont_show = True
                        break
                    elif key == ord("s"):
                        cv2.imwrite(
                            "vid {} t:{} p:{} c:{}.png".format(
                                curr_domain, target_action.argmax(),
                                action_out.data[j].argmax(),
                                score_out.data[j].item()), im)

        print("Vid. {}".format(curr_domain))
        print("\tAccuracy: {}".format(np.mean(accuracy)))
        print("\tScore loss: {}".format(np.mean(score_loss_val)))
        print("\tAction loss: {}".format(np.mean(action_loss_val)))

    sys.exit(0)
    return net, domain_specific_nets, train_videos