Example #1
0
# Evaluate every video of the VOT2015 set. The network is rebuilt for each
# video with a randomly initialised domain-specific branch.
# NOTE: the dataset is hard-coded to 'vot15' (previously opts['test_db']).
dataset_root = os.path.join('../datasets/data', 'vot15')
# Every sub-directory of dataset_root is one video; sort case-insensitively
# so the evaluation order is deterministic.
vid_folders = sorted(
    (name for name in os.listdir(dataset_root)
     if os.path.isdir(os.path.join(dataset_root, name))),
    key=str.lower)

# Remember the user-supplied output roots: the corresponding args fields are
# overwritten with per-video sub-directories inside the loop.
save_root = args.save_result_images
save_root_npy = args.save_result_npy

for vid_folder in vid_folders:
    print('Loading {}...'.format(args.weight_file))
    opts['num_videos'] = 1
    net, domain_nets = adnet(opts, trained_file=args.weight_file, random_initialize_domain_specific=True, vid_index=args.vid_index)
    net.train()
    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    if args.save_result_images is not None:
        args.save_result_images = os.path.join(save_root, vid_folder)
        # makedirs (not mkdir) so a missing save_root parent is created too.
        os.makedirs(args.save_result_images, exist_ok=True)

    args.save_result_npy = os.path.join(save_root_npy, vid_folder)
def adnet_train_sl(args, opts):
    """Train ADNet with supervised learning (SL) on the training videos.

    Each epoch draws positive and negative minibatches in a shuffled order.
    When ``args.multidomain`` is set, each iteration also selects one video
    (domain). That video's domain-specific branch is loaded into the network
    before the forward pass and copied back after the optimizer step.

    Args:
        args: parsed command-line namespace. Fields read here: ``cuda``,
            ``save_folder``, ``save_file``, ``visualize``, ``resume``,
            ``multidomain``, ``num_workers``, ``start_epoch``, ``start_iter``
            and ``send_images_to_visualization``.
        opts: option dict. ``opts['num_videos']`` is overwritten with the
            number of training videos.

    Returns:
        tuple: ``(net, domain_specific_nets, train_videos)``. ``net`` is
        wrapped in ``nn.DataParallel`` when ``args.cuda`` is set.

    Side effects:
        Writes checkpoints into ``args.save_folder`` at every epoch boundary,
        every 5000 iterations, and once more at the end (``<save_file>.pth``).
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " + "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    writer = None
    if args.visualize:
        writer = SummaryWriter(log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts, trained_file=args.resume, multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()
    # The bare ADNet module: DataParallel hides it behind ``.module``.
    model = net.module if args.cuda else net

    optimizer = optim.SGD([
        {'params': model.base_network.parameters(), 'lr': 1e-4},
        {'params': model.fc4_5.parameters()},
        {'params': model.fc6.parameters()},
        {'params': model.fc7.parameters()}],  # as action dynamic is zero, it doesn't matter
        lr=1e-3, momentum=opts['train']['momentum'], weight_decay=opts['train']['weightDecay'])

    if args.resume:
        checkpoint = torch.load(args.resume)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    net.train()

    if not args.resume:
        print('Initializing weights...')
        # Each fc layer gets weight = N(0, 1) * 0.01. Biases are 0.1 for
        # fc4/fc5 (fc4_5[0], fc4_5[3]) and 0 for fc6/fc7.
        for layer, bias_value in ((model.fc4_5[0], 0.1), (model.fc4_5[3], 0.1),
                                  (model.fc6, 0), (model.fc7, 0)):
            nn.init.normal_(layer.weight.data)
            layer.weight.data.mul_(0.01)
            layer.bias.data.fill_(bias_value)

    action_criterion = nn.CrossEntropyLoss()
    score_criterion = nn.CrossEntropyLoss()

    print('generating Supervised Learning dataset..')
    datasets_pos, datasets_neg = initialize_pos_neg_dataset(train_videos, opts, transform=ADNet_Augmentation(opts))
    number_domain = opts['num_videos']

    # calculating number of data
    len_dataset_pos = sum(len(dataset_pos) for dataset_pos in datasets_pos)
    len_dataset_neg = sum(len(dataset_neg) for dataset_neg in datasets_neg)

    epoch_size_pos = len_dataset_pos // opts['minibatch_size']
    epoch_size_neg = len_dataset_neg // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations
    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    data_loaders_pos = [
        data.DataLoader(dataset_pos, opts['minibatch_size'], num_workers=args.num_workers, shuffle=True, pin_memory=True)
        for dataset_pos in datasets_pos]
    data_loaders_neg = [
        data.DataLoader(dataset_neg, opts['minibatch_size'], num_workers=args.num_workers, shuffle=True, pin_memory=True)
        for dataset_neg in datasets_neg]

    epoch = args.start_epoch
    if epoch != 0 and args.start_iter == 0:
        start_iter = epoch * epoch_size
    else:
        start_iter = args.start_iter

    # One entry per iteration of an epoch: 1 -> positive batch, 0 -> negative batch.
    which_dataset = list(np.full(epoch_size_pos, fill_value=1))
    which_dataset.extend(np.zeros(epoch_size_neg, dtype=int))
    shuffle(which_dataset)

    which_domain = np.random.permutation(number_domain)

    def save_checkpoint(path, curr_epoch):
        """Write the SL checkpoint dict (network, domain nets, optimizer) to ``path``."""
        # NOTE(review): despite its key name, 'adnet_domain_specific_state_dict'
        # holds the module objects, not their state_dicts. Kept unchanged
        # because the loaders of this checkpoint are not visible here; confirm
        # before switching to state_dicts.
        torch.save({
            'epoch': curr_epoch,
            'adnet_state_dict': net.state_dict(),
            'adnet_domain_specific_state_dict': domain_specific_nets,
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)

    checkpoint_prefix = os.path.join(args.save_folder, args.save_file)

    batch_iterators_pos = []
    batch_iterators_neg = []
    action_loss = 0
    score_loss = 0

    # training loop
    for iteration in range(start_iter, max_iter):
        if args.multidomain:
            curr_domain = which_domain[iteration % len(which_domain)]
        else:
            curr_domain = 0

        # if new epoch (not including the very first iteration)
        if (iteration != start_iter) and (iteration % epoch_size == 0):
            epoch += 1
            shuffle(which_dataset)
            np.random.shuffle(which_domain)

            print('Saving state, epoch:', epoch)
            save_checkpoint(checkpoint_prefix + 'epoch' + repr(epoch) + '.pth', epoch)

            if args.visualize:
                writer.add_scalars('data/epoch_loss', {'action_loss': action_loss / epoch_size,
                                                       'score_loss': score_loss / epoch_size,
                                                       'total': (action_loss + score_loss) / epoch_size}, global_step=epoch)

            # reset epoch loss counters
            action_loss = 0
            score_loss = 0

        # At every epoch start (or on the first iteration after resuming),
        # rebuild the iterator lists instead of appending to them. Indexing by
        # curr_domain then always hits the fresh iterators, and the stale ones
        # (with their DataLoader worker processes) are released.
        if iteration % epoch_size == 0 or not batch_iterators_pos or not batch_iterators_neg:
            batch_iterators_pos = [iter(data_loader_pos) for data_loader_pos in data_loaders_pos]
            batch_iterators_neg = [iter(data_loader_neg) for data_loader_neg in data_loaders_neg]

        # load train data
        is_positive = which_dataset[iteration % len(which_dataset)]
        if is_positive:
            batch_iterators, data_loaders = batch_iterators_pos, data_loaders_pos
        else:
            batch_iterators, data_loaders = batch_iterators_neg, data_loaders_neg
        try:
            batch = next(batch_iterators[curr_domain])
        except StopIteration:
            # This domain's loader is exhausted mid-epoch: restart it.
            batch_iterators[curr_domain] = iter(data_loaders[curr_domain])
            batch = next(batch_iterators[curr_domain])
        images, bbox, action_label, score_label, vid_idx = batch

        # TODO: check if this requires grad is really false like in Variable
        if args.cuda:
            images = torch.Tensor(images.cuda())
            bbox = torch.Tensor(bbox.cuda())
            action_label = torch.Tensor(action_label.cuda())
            score_label = torch.Tensor(score_label.float().cuda())
        else:
            images = torch.Tensor(images)
            bbox = torch.Tensor(bbox)
            action_label = torch.Tensor(action_label)
            # .float() as in the CUDA branch: torch.Tensor() rejects non-float tensors.
            score_label = torch.Tensor(score_label.float())

        t0 = time.time()

        # load ADNetDomainSpecific with video index
        model.load_domain_specific(domain_specific_nets[curr_domain])

        # forward
        action_out, score_out = net(images)

        # backprop
        optimizer.zero_grad()
        if is_positive:
            action_l = action_criterion(action_out, torch.max(action_label, 1)[1])
        else:
            # Negative samples carry no action target.
            action_l = torch.Tensor([0])
        score_l = score_criterion(score_out, score_label.long())
        loss = action_l + score_l
        loss.backward()
        optimizer.step()

        action_loss += action_l.item()
        score_loss += score_l.item()

        # save the ADNetDomainSpecific back to their module
        domain_specific_nets[curr_domain].load_weights_from_adnet(model)

        t1 = time.time()

        if iteration % 10 == 0:
            print('Timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data.item()), end=' ')
            if args.visualize and args.send_images_to_visualization:
                random_batch_index = np.random.randint(images.size(0))
                writer.add_image('image', images.data[random_batch_index].cpu().numpy(), random_batch_index)

        if args.visualize:
            writer.add_scalars('data/iter_loss', {'action_loss': action_l.item(),
                                                  'score_loss': score_l.item(),
                                                  'total': (action_l.item() + score_l.item())}, global_step=iteration)
            # hacky fencepost solution for 0th epoch plot
            if iteration == 0:
                writer.add_scalars('data/epoch_loss', {'action_loss': action_loss,
                                                       'score_loss': score_loss,
                                                       'total': (action_loss + score_loss)}, global_step=epoch)

        if iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            save_checkpoint(checkpoint_prefix + repr(iteration) + '_epoch' + repr(epoch) + '.pth', epoch)

    # final save
    save_checkpoint(checkpoint_prefix + '.pth', epoch)

    if writer is not None:
        writer.close()

    return net, domain_specific_nets, train_videos
Example #3
0
parser.add_argument('--multidomain', default=True, type=str2bool, help='Separating weight for each videos (default) or not')

parser.add_argument('--save_result_images', default=True, type=str2bool, help='Whether to save the results or not. Save folder: images/')
parser.add_argument('--display_images', default=True, type=str2bool, help='Whether to display images or not')

args = parser.parse_args()

# Supervised Learning part
if args.run_supervised:
    opts['minibatch_size'] = 128
    # train with supervised learning
    _, _, train_videos = adnet_train_sl(args, opts)
    # Continue from the final checkpoint path adnet_train_sl writes to.
    args.resume = os.path.join(args.save_folder, args.save_file) + '.pth'

    # reinitialize the network with network from SL
    net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=True,
                                      multidomain=args.multidomain)

    args.start_epoch = 0
    args.start_iter = 0

else:
    # parser.error instead of assert so the check survives `python -O`;
    # it prints the usage message and exits with status 2.
    if args.resume is None:
        parser.error("Please put result of supervised learning or reinforcement learning with --resume (filename)")
    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    # start_iter == 0 means the weights came from SL, so the domain-specific
    # branches are randomly initialized; otherwise an ADNet run is resumed.
    net, domain_specific_nets = adnet(opts, trained_file=args.resume,
                                      random_initialize_domain_specific=(args.start_iter == 0),
                                      multidomain=args.multidomain)
Example #4
0
def do():
    """Parse command-line options, then run ADNet SL and/or RL training.

    With ``--run_supervised``, the SL stage runs first:
    ``adnet_test_sl`` when ``--test`` is set (which, as written, exits the
    process), otherwise ``adnet_train_sl_mot`` / ``adnet_train_sl``. Its final
    checkpoint then seeds the network. Without ``--run_supervised``,
    ``--resume`` names the checkpoint to load. The reinforcement-learning
    stage then runs with ``adnet_train_rl_mot`` or ``adnet_train_rl``.
    """
    parser = argparse.ArgumentParser(description='ADNet training')
    # Boolean options use str2bool: with type=str or type=bool, argparse
    # would turn "--test False" into a truthy value.
    parser.add_argument('--adnet_mot',
                        default=False,
                        type=str2bool,
                        help='Whether to use the MOT variant of ADNet.')

    parser.add_argument('--test',
                        default=False,
                        type=str2bool,
                        help='Whether to test or train.')
    parser.add_argument('--mot',
                        default=False,
                        type=str2bool,
                        help='Perform MOT tracking')
    parser.add_argument('--resume',
                        default='weights/ADNet_RL_FINAL.pth',
                        type=str,
                        help='Resume from checkpoint')
    parser.add_argument('--num_workers',
                        default=4,
                        type=int,
                        help='Number of workers used in dataloading')
    parser.add_argument(
        '--start_iter',
        default=0,
        type=int,
        help=
        'Begin counting iterations starting from this value (should be used with resume)'
    )
    parser.add_argument('--cuda',
                        default=True,
                        type=str2bool,
                        help='Use cuda to train model')
    parser.add_argument('--gamma',
                        default=0.1,
                        type=float,
                        help='Gamma update for SGD')
    parser.add_argument('--visualize',
                        default=True,
                        type=str2bool,
                        help='Use tensorboardx for loss visualization')
    parser.add_argument(
        '--send_images_to_visualization',
        type=str2bool,
        default=False,
        help=
        'Sample a random image from each 10th batch, send it to visdom after augmentations step'
    )
    parser.add_argument('--save_folder',
                        default='weights',
                        help='Location to save checkpoint models')

    parser.add_argument('--save_file',
                        default='ADNet_SL_MOT',
                        type=str,
                        help='save file part of file name for SL')
    parser.add_argument('--save_file_RL',
                        default='ADNet_RL_',
                        type=str,
                        help='save file part of file name for RL')
    parser.add_argument('--start_epoch',
                        default=0,
                        type=int,
                        help='Begin counting epochs starting from this value')

    parser.add_argument('--run_supervised',
                        default=False,
                        type=str2bool,
                        help='Whether to run supervised learning or not')

    parser.add_argument(
        '--multidomain',
        default=True,
        type=str2bool,
        help='Separating weight for each videos (default) or not')

    parser.add_argument(
        '--save_result_images',
        default=False,
        type=str2bool,
        help='Whether to save the results or not. Save folder: images/')
    parser.add_argument('--display_images',
                        default=True,
                        type=str2bool,
                        help='Whether to display images or not')

    args = parser.parse_args()

    # Supervised Learning part
    if args.run_supervised:
        opts['minibatch_size'] = 256
        # train with supervised learning
        if args.test:
            args.save_file += "test"
            _, _, train_videos = adnet_test_sl(args, opts, mot=args.mot)
        elif args.adnet_mot:
            _, _, train_videos = adnet_train_sl_mot(args, opts, mot=args.mot)
        else:
            _, _, train_videos = adnet_train_sl(args, opts, mot=args.mot)
        args.resume = os.path.join(args.save_folder, args.save_file) + '.pth'

        # reinitialize the network with network from SL
        # NOTE(review): this always rebuilds with adnet(), even when
        # --adnet_mot trained with adnet_train_sl_mot above, whereas the
        # resume branch below uses adnet_mot in that case. Confirm which
        # architecture the SL-MOT checkpoint holds.
        net, domain_specific_nets = adnet(
            opts,
            trained_file=args.resume,
            random_initialize_domain_specific=True,
            multidomain=args.multidomain)

        args.start_epoch = 0
        args.start_iter = 0

    else:
        # parser.error instead of assert so the check survives `python -O`.
        if args.resume is None:
            parser.error(
                "Please put result of supervised learning or reinforcement learning with --resume (filename)")
        train_videos = get_train_videos(opts)
        opts['num_videos'] = len(train_videos['video_names'])

        # Resume the adnet with the checkpoint's domain-specific weights.
        build_net = adnet_mot if args.adnet_mot else adnet
        net, domain_specific_nets = build_net(
            opts,
            trained_file=args.resume,
            random_initialize_domain_specific=False,
            multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    # Reinforcement Learning part
    opts['minibatch_size'] = opts['train']['RL_steps']
    if args.adnet_mot:
        net = adnet_train_rl_mot(net, domain_specific_nets, train_videos,
                                 opts, args, 2)
    else:
        net = adnet_train_rl(net, domain_specific_nets, train_videos, opts,
                             args)
Example #5
0
def adnet_test_sl(args, opts, mot):
    """Evaluate an SL-trained ADNet on each training video and print metrics.

    For every video (domain), the matching domain-specific branch is loaded
    and all of that video's positive and negative batches are run through the
    network. Action accuracy and action loss use positive batches only; score
    loss uses both.

    With ``args.display_images``, every sample is drawn with OpenCV and the
    loop waits for a key. ``q`` stops displaying for the rest of the current
    positive/negative pass; ``s`` writes the frame to a PNG.

    Args:
        args: parsed CLI namespace (``cuda``, ``save_folder``, ``resume``,
            ``multidomain``, ``display_images``).
        opts: option dict; ``opts['num_videos']`` is overwritten.
        mot: build the datasets with ``initialize_pos_neg_dataset_mot`` when
            truthy, else with ``initialize_pos_neg_dataset``.

    Note:
        Calls ``sys.exit(0)`` after printing the metrics, so it never returns.
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts,
                                      trained_file=args.resume,
                                      multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()
    # Respect --cuda instead of always moving batches to the GPU.
    device = 'cuda' if args.cuda else 'cpu'

    net.eval()

    action_criterion = nn.CrossEntropyLoss()
    score_criterion = nn.BCELoss()

    print('generating Supervised Learning dataset..')
    if mot:
        datasets_pos, datasets_neg = initialize_pos_neg_dataset_mot(
            train_videos, opts, transform=ADNet_Augmentation(opts))
    else:
        datasets_pos, datasets_neg = initialize_pos_neg_dataset(
            train_videos, opts, transform=ADNet_Augmentation(opts))
    number_domain = opts['num_videos']
    assert number_domain == len(
        datasets_pos
    ), "Num videos given in opts is incorrect! It should be {}".format(
        len(datasets_pos))

    # calculating number of data
    len_dataset_pos = sum(len(dataset_pos) for dataset_pos in datasets_pos)
    len_dataset_neg = sum(len(dataset_neg) for dataset_neg in datasets_neg)

    epoch_size_pos = len_dataset_pos // opts['minibatch_size']
    epoch_size_neg = len_dataset_neg // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations
    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    data_loaders_pos_val = [
        data.DataLoader(dataset_pos,
                        opts['minibatch_size'],
                        num_workers=2,
                        shuffle=True,
                        pin_memory=True) for dataset_pos in datasets_pos
    ]
    data_loaders_neg_val = [
        data.DataLoader(dataset_neg,
                        opts['minibatch_size'],
                        num_workers=2,
                        shuffle=True,
                        pin_memory=True) for dataset_neg in datasets_neg
    ]

    for curr_domain in range(number_domain):
        accuracy = []
        action_loss_val = []
        score_loss_val = []

        # load ADNetDomainSpecific with video index
        if args.cuda:
            net.module.load_domain_specific(domain_specific_nets[curr_domain])
        else:
            net.load_domain_specific(domain_specific_nets[curr_domain])

        # i == 0: positive loader, i == 1: negative loader.
        for i, loader in enumerate([
                data_loaders_pos_val[curr_domain],
                data_loaders_neg_val[curr_domain]
        ]):
            dont_show = False
            for images, bbox, action_label, score_label, indices in tqdm(loader):
                images = images.to(device, non_blocking=True)
                action_label = action_label.to(device, non_blocking=True)
                score_label = score_label.float().to(device, non_blocking=True)

                # forward (evaluation only: no autograd graph needed)
                with torch.no_grad():
                    action_out, score_out = net(images)

                if i == 0:  # if positive
                    action_l = action_criterion(action_out,
                                                torch.max(action_label, 1)[1])
                    action_loss_val.append(action_l.item())
                    accuracy.append(
                        int(
                            action_label.argmax(dim=1).eq(
                                action_out.argmax(dim=1)).sum()) /
                        len(action_label))

                score_l = score_criterion(score_out,
                                          score_label.reshape(-1, 1))
                score_loss_val.append(score_l.item())

                if args.display_images and not dont_show:
                    if i == 0:
                        dataset = datasets_pos[curr_domain]
                        color = (0, 255, 0)
                        conf = 1
                    else:
                        dataset = datasets_neg[curr_domain]
                        color = (0, 0, 255)
                        conf = 0
                    # j indexes the sample within this batch.
                    for j, index in enumerate(indices):
                        im = cv2.imread(dataset.train_db['img_path'][index])
                        # cv2 drawing functions require integer pixel coordinates.
                        x, y, w, h = (int(v) for v in dataset.train_db['bboxes'][index][:4])
                        target_action = np.array(
                            dataset.train_db['labels'][index])
                        cv2.rectangle(im, (x, y), (x + w, y + h), color, 2)

                        print("\n\nTarget actions: {}".format(
                            target_action.argmax()))
                        print("Predicted actions: {}".format(
                            action_out.data[j].argmax()))

                        print("Target conf: {}".format(conf))
                        print("Predicted conf: {}".format(score_out.data[j]))
                        cv2.imshow("Test", im)
                        key = cv2.waitKey(0) & 0xFF

                        # if the `q` key was pressed, break from the loop
                        if key == ord("q"):
                            dont_show = True
                            break
                        elif key == ord("s"):
                            # Use the per-sample index j (not the pos/neg pass index i).
                            cv2.imwrite(
                                "vid {} t:{} p:{} c:{}.png".format(
                                    curr_domain, target_action.argmax(),
                                    action_out.data[j].argmax(),
                                    score_out.data[j].item()), im)

        print("Vid. {}".format(curr_domain))
        print("\tAccuracy: {}".format(np.mean(accuracy)))
        print("\tScore loss: {}".format(np.mean(score_loss_val)))
        print("\tAction loss: {}".format(np.mean(action_loss_val)))

    sys.exit(0)
    # Unreachable while sys.exit above stays; documents the tuple callers unpack.
    return net, domain_specific_nets, train_videos
Example #6
0
def process_adnet_test(videos_infos,dataset_start_id, v_start_id,v_end_id,train_videos,save_root,
                        spend_times_share,vid_preds, opts,args, lock):
    """Track videos ``v_start_id`` .. ``v_end_id - 1`` with ADNet and a detector.

    Builds the ADNet tracker, an optional Siamese network
    (``args.useSiamese``) and a detectron2 ``DefaultPredictor``. It then calls
    ``adnet_test`` on every video in the range. After each video, while
    holding ``lock``:

    - the video's timing counters are added into the dict stored at
      ``spend_times_share[0]``;
    - its prediction is stored at ``vid_preds[vidx - dataset_start_id]``.

    The dict is copied, updated and written back as a whole, presumably
    because the shared containers are multiprocessing-manager proxies
    (TODO confirm against the caller).
    """
    # '' is the "no Siamese network" value handed to adnet_test.
    siamesenet = ''
    if args.useSiamese:
        siamesenet = SiameseNetwork().cuda()
        resume = args.weight_siamese
        if resume:
            siamesenet.load_weights(resume)

    net, domain_nets = adnet(opts, trained_file=args.weight_file, random_initialize_domain_specific=False)
    net.eval()
    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()
        net.module.set_phase('test')
    else:
        net.set_phase('test')

    register_ILSVRC()
    cfg = get_cfg()
    cfg.merge_from_file("../../../configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
    cfg.MODEL.WEIGHTS = args.weight_detector
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 30
    metalog = MetadataCatalog.get("ILSVRC_VID_val")

    predictor = DefaultPredictor(cfg)
    class_names = metalog.get("thing_classes", None)

    # Counters returned by adnet_test that are summed across all videos.
    timing_keys = (
        'predict', 'n_predict_frames',
        'track', 'n_track_frames',
        'readframe', 'n_readframe',
        'append', 'n_append',
        'transform', 'n_transform',
        'argmax_after_forward', 'n_argmax_after_forward',
        'do_action', 'n_do_action',
    )

    for vidx in range(v_start_id, v_end_id):
        vid_folder = videos_infos[vidx]

        if args.save_result_images_bool:
            args.save_result_images = os.path.join(save_root, train_videos['video_names'][vidx])
            # exist_ok avoids the check-then-create race between worker processes.
            os.makedirs(args.save_result_images, exist_ok=True)

        vid_pred, spend_time = adnet_test(net, predictor, siamesenet, metalog, class_names, vidx,
                                          vid_folder['img_files'], opts, args)

        # `with` guarantees release and never releases a lock that was not acquired.
        with lock:
            spend_times = spend_times_share[0].copy()
            for key in timing_keys:
                spend_times[key] += spend_time[key]
            spend_times_share[0] = spend_times
            vid_preds[vidx - dataset_start_id] = vid_pred
Example #7
0
    # dataset_root = os.path.join('../datasets/data', opts['test_db'])
    # vid_folders = []
    # for filename in os.listdir(dataset_root):
    #     if os.path.isdir(os.path.join(dataset_root,filename)):
    #         vid_folders.append(filename)
    # vid_folders.sort(key=str.lower)
    # all_precisions = []

        save_root = args.save_result_images
    # save_root_npy = args.save_result_npy

    opts['num_videos'] = 1
    if not args.multi_cpu_eval:
        print('Loading {}...'.format(args.weight_file))

        net, domain_nets = adnet(opts, trained_file=args.weight_file, random_initialize_domain_specific=False)
        net.eval()
        if args.cuda:
            net = nn.DataParallel(net)
            cudnn.benchmark = True
        if args.cuda:
            net = net.cuda()
        if args.cuda:
            net.module.set_phase('test')
        else:
            net.set_phase('test')

    if args.test1vid:
        vid_path = args.testVidPath
        vid_folder = vid_path.split('/')[-2]
        # vid_path = "../../../demo/examples/jiaotong2.avi"
Example #8
0
def _save_adnet_checkpoint(path, epoch, net, domain_specific_nets, optimizer):
    """Write an ADNet supervised-learning checkpoint to ``path``.

    Args:
        path: destination file handed to ``torch.save``.
        epoch: epoch counter stored under ``'epoch'``.
        net: the (possibly ``DataParallel``-wrapped) network; its
            ``state_dict()`` is stored under ``'adnet_state_dict'``.
        domain_specific_nets: list of per-video heads. They are stored as the
            module objects themselves (not their state dicts), which is the
            checkpoint layout this training code has always written.
        optimizer: optimizer whose ``state_dict()`` is stored.
    """
    torch.save(
        {
            'epoch': epoch,
            'adnet_state_dict': net.state_dict(),
            'adnet_domain_specific_state_dict': domain_specific_nets,
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)


def _next_batch(iterators, loaders, index):
    """Return the next batch from ``iterators[index]``.

    When that iterator is exhausted it is replaced in place by a fresh
    ``iter(loaders[index])`` and the first batch of the new pass is returned.
    """
    try:
        return next(iterators[index])
    except StopIteration:
        iterators[index] = iter(loaders[index])
        return next(iterators[index])


def _make_train_val_loaders(datasets, opts, train_workers):
    """Split every dataset into train/val subsets and wrap each in a DataLoader.

    ``int(opts['val_percent'] * len(dataset))`` samples of each dataset go to
    validation, the rest to training. Train loaders use ``train_workers``
    worker processes and pinned memory; validation loaders load in the main
    process without pinning. Both shuffle and use ``opts['minibatch_size']``.

    Returns:
        (train_loaders, val_loaders): two lists aligned with ``datasets``.
    """
    train_loaders = []
    val_loaders = []
    for dataset in datasets:
        num_val = int(opts['val_percent'] * len(dataset))
        num_train = len(dataset) - num_val
        train, valid = torch.utils.data.random_split(dataset,
                                                     [num_train, num_val])
        train_loaders.append(
            data.DataLoader(train,
                            opts['minibatch_size'],
                            num_workers=train_workers,
                            shuffle=True,
                            pin_memory=True))
        val_loaders.append(
            data.DataLoader(valid,
                            opts['minibatch_size'],
                            num_workers=0,
                            shuffle=True,
                            pin_memory=False))
    return train_loaders, val_loaders


def _evaluate_sl_domain(net, loader_pos, loader_neg, action_criterion,
                        score_criterion, device):
    """Evaluate ``net`` on one domain's validation loaders without training it.

    The caller must already have loaded the matching domain-specific head.
    Action accuracy/loss are only collected on the positive loader (negatives
    carry no meaningful action label); the score loss is collected on both.
    The network is put in ``eval()`` mode (so dropout-like layers, if any, are
    disabled) and under ``torch.no_grad()``, then switched back to ``train()``.

    Returns:
        (accuracy, action_losses, score_losses): per-batch Python lists.
    """
    accuracy = []
    action_losses = []
    score_losses = []
    net.eval()
    with torch.no_grad():
        for is_positive, loader in ((True, loader_pos), (False, loader_neg)):
            for images, _bbox, action_label, score_label, _ in loader:
                images = images.to(device, non_blocking=True)
                action_label = action_label.to(device, non_blocking=True)
                score_label = score_label.float().to(device, non_blocking=True)

                action_out, score_out = net(images)

                if is_positive:
                    action_l = action_criterion(
                        action_out, torch.max(action_label, 1)[1])
                    accuracy.append(
                        int(
                            action_label.argmax(axis=1).eq(
                                action_out.argmax(axis=1)).sum()) /
                        len(action_label))
                    action_losses.append(action_l.item())

                score_l = score_criterion(score_out,
                                          score_label.reshape(-1, 1))
                score_losses.append(score_l.item())
    net.train()
    return accuracy, action_losses, score_losses


def adnet_train_sl(args, opts, mot):
    """Train ADNet with supervised learning on positive/negative patches.

    Every iteration draws one minibatch, either positive or negative according
    to a per-epoch shuffled schedule, from one video ("domain"). It loads that
    video's domain-specific head into the network, takes one optimizer step
    and copies the updated head back. At every epoch boundary a checkpoint is
    written and each domain is validated on its held-out split.

    Args:
        args: parsed command-line namespace. Fields read here: ``cuda``,
            ``save_folder``, ``save_file``, ``visualize``,
            ``send_images_to_visualization``, ``resume``, ``multidomain``,
            ``start_epoch``, ``start_iter``.
        opts: option dict (``minibatch_size``, ``numEpoch``, ``val_percent``,
            ``opts['train']['momentum']``/``['weightDecay']``, ...).
            ``opts['num_videos']`` is overwritten with the number of training
            videos.
        mot: when truthy, datasets are built with
            ``initialize_pos_neg_dataset_mot`` instead of
            ``initialize_pos_neg_dataset``.

    Returns:
        ``(net, domain_specific_nets, train_videos)``.
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    if args.visualize:
        writer = SummaryWriter(
            log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts,
                                      trained_file=args.resume,
                                      multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    # ``model`` is the bare ADNet; DataParallel hides it behind ``.module``.
    model = net.module if args.cuda else net
    device = torch.device('cuda' if args.cuda else 'cpu')

    # The backbone is fine-tuned with a 10x smaller learning rate than the
    # fully-connected layers (1e-4 vs. the optimizer default of 1e-3).
    # NOTE(review): CUDA runs use Adam while CPU runs use SGD with momentum;
    # kept as-is, but confirm the split is intentional.
    param_groups = [
        {'params': model.base_network.parameters(), 'lr': 1e-4},
        {'params': model.fc4_5.parameters()},
        {'params': model.fc6.parameters()},
        {'params': model.fc7.parameters()},  # as action dynamic is zero, it doesn't matter
    ]
    if args.cuda:
        optimizer = optim.Adam(param_groups,
                               lr=1e-3,
                               weight_decay=opts['train']['weightDecay'])
    else:
        optimizer = optim.SGD(param_groups,
                              lr=1e-3,
                              momentum=opts['train']['momentum'],
                              weight_decay=opts['train']['weightDecay'])

    if args.resume:
        checkpoint = torch.load(args.resume)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    net.train()

    if not args.resume:
        print('Initializing weights...')
        norm_std = 0.01
        # fc4 and fc5 live at indices 0 and 3 of the fc4_5 sequential.
        for layer in (model.fc4_5[0], model.fc4_5[3]):
            nn.init.normal_(layer.weight.data, std=norm_std)
            layer.bias.data.fill_(0.1)
        for layer in (model.fc6, model.fc7):
            nn.init.normal_(layer.weight.data, std=norm_std)
            layer.bias.data.fill_(0)

    action_criterion = nn.CrossEntropyLoss()
    score_criterion = nn.BCELoss()

    print('generating Supervised Learning dataset..')
    init_datasets = (initialize_pos_neg_dataset_mot
                     if mot else initialize_pos_neg_dataset)
    datasets_pos, datasets_neg = init_datasets(
        train_videos, opts, transform=ADNet_Augmentation(opts))
    number_domain = opts['num_videos']
    assert number_domain == len(
        datasets_pos
    ), "Num videos given in opts is incorrect! It should be {}".format(
        len(datasets_pos))

    len_dataset_pos = sum(len(dataset_pos) for dataset_pos in datasets_pos)
    len_dataset_neg = sum(len(dataset_neg) for dataset_neg in datasets_neg)

    epoch_size_pos = len_dataset_pos // opts['minibatch_size']
    epoch_size_neg = len_dataset_neg // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations
    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    data_loaders_pos_train, data_loaders_pos_val = _make_train_val_loaders(
        datasets_pos, opts, train_workers=2)
    data_loaders_neg_train, data_loaders_neg_val = _make_train_val_loaders(
        datasets_neg, opts, train_workers=1)

    epoch = args.start_epoch
    if epoch != 0 and args.start_iter == 0:
        start_iter = epoch * epoch_size
    else:
        start_iter = args.start_iter

    # One entry per iteration of an epoch: 1 -> positive batch, 0 -> negative.
    # Kept as a NumPy array: random.shuffle on a 1-D torch tensor swaps
    # through element views and can duplicate entries, and indexing a CUDA
    # tensor would force a device sync on every iteration.
    which_dataset = np.concatenate([
        np.ones(epoch_size_pos, dtype=int),
        np.zeros(epoch_size_neg, dtype=int)
    ])
    np.random.shuffle(which_dataset)

    which_domain = np.random.permutation(number_domain)

    batch_iterators_pos_train = [iter(dl) for dl in data_loaders_pos_train]
    batch_iterators_neg_train = [iter(dl) for dl in data_loaders_neg_train]

    action_loss_tr = 0
    score_loss_tr = 0

    # training loop
    time_arr = np.zeros(10)
    for iteration in tqdm(range(start_iter, max_iter)):
        t0 = time.time()
        if args.multidomain:
            curr_domain = which_domain[iteration % len(which_domain)]
        else:
            curr_domain = 0

        # if new epoch (not including the very first iteration)
        if (iteration != start_iter) and (iteration % epoch_size == 0):
            epoch += 1
            np.random.shuffle(which_dataset)
            np.random.shuffle(which_domain)

            print('Saving state, epoch: {}'.format(epoch))
            _save_adnet_checkpoint(
                os.path.join(args.save_folder, args.save_file) + 'epoch' +
                repr(epoch) + '.pth', epoch, net, domain_specific_nets,
                optimizer)

            # VAL
            for curr_domain_temp in range(number_domain):
                model.load_domain_specific(
                    domain_specific_nets[curr_domain_temp])
                accuracy, action_loss_val, score_loss_val = _evaluate_sl_domain(
                    net, data_loaders_pos_val[curr_domain_temp],
                    data_loaders_neg_val[curr_domain_temp], action_criterion,
                    score_criterion, device)
                print("Vid. {}".format(curr_domain_temp))
                print("\tAccuracy: {}".format(np.mean(accuracy)))
                print("\tScore loss: {}".format(np.mean(score_loss_val)))
                print("\tAction loss: {}".format(np.mean(action_loss_val)))
                if args.visualize:
                    writer.add_scalars(
                        'data/val_video_{}'.format(curr_domain_temp), {
                            'action_loss_val':
                            np.mean(action_loss_val),
                            'score_loss_val':
                            np.mean(score_loss_val),
                            'total_val':
                            np.mean(score_loss_val) + np.mean(action_loss_val),
                            'accuracy':
                            np.mean(accuracy)
                        },
                        global_step=epoch)

            if args.visualize:
                # Action loss is only accumulated on positive iterations.
                writer.add_scalars('data/epoch_loss', {
                    'action_loss_tr':
                    action_loss_tr / epoch_size_pos,
                    'score_loss_tr':
                    score_loss_tr / epoch_size,
                    'total_tr':
                    action_loss_tr / epoch_size_pos +
                    score_loss_tr / epoch_size
                },
                                   global_step=epoch)

            # reset epoch loss counters
            action_loss_tr = 0
            score_loss_tr = 0

        # load train data
        is_positive = bool(which_dataset[iteration % len(which_dataset)])
        if is_positive:
            batch = _next_batch(batch_iterators_pos_train,
                                data_loaders_pos_train, curr_domain)
        else:
            batch = _next_batch(batch_iterators_neg_train,
                                data_loaders_neg_train, curr_domain)
        images, bbox, action_label, score_label, vid_idx = batch

        images = images.to(device, non_blocking=True)
        action_label = action_label.to(device, non_blocking=True)
        score_label = score_label.float().to(device, non_blocking=True)

        # TRAIN: the domain-specific head must be loaded BEFORE the forward
        # pass so the outputs, loss and gradients belong to this video.
        model.load_domain_specific(domain_specific_nets[curr_domain])
        net.train()
        action_out, score_out = net(images)

        # backprop
        optimizer.zero_grad()
        score_l = score_criterion(score_out, score_label.reshape(-1, 1))
        if is_positive:
            action_l = action_criterion(action_out,
                                        torch.max(action_label, 1)[1])
            accuracy = int(
                action_label.argmax(axis=1).eq(
                    action_out.argmax(axis=1)).sum()) / len(action_label)
            loss = action_l + score_l
        else:
            # negatives carry no meaningful action label
            action_l = None
            accuracy = None
            loss = score_l
        loss.backward()
        optimizer.step()

        if action_l is not None:
            action_loss_tr += action_l.item()
        score_loss_tr += score_l.item()

        # save the ADNetDomainSpecific back to their module
        domain_specific_nets[curr_domain].load_weights_from_adnet(model)

        if args.visualize:
            if action_l is not None:
                writer.add_scalars(
                    'data/iter_loss', {
                        'action_loss_tr': action_l.item(),
                        'score_loss_tr': score_l.item(),
                        'total_tr': (action_l.item() + score_l.item())
                    },
                    global_step=iteration)
            else:
                writer.add_scalars('data/iter_loss', {
                    'score_loss_tr': score_l.item(),
                    'total_tr': score_l.item()
                },
                                   global_step=iteration)
            if accuracy is not None:
                writer.add_scalars('data/iter_acc', {'accuracy_tr': accuracy},
                                   global_step=iteration)

        t1 = time.time()
        time_arr[iteration % 10] = t1 - t0

        if iteration % 10 == 0:
            if args.visualize and args.send_images_to_visualization:
                random_batch_index = np.random.randint(images.size(0))
                writer.add_image('image',
                                 images.data[random_batch_index].cpu().numpy(),
                                 random_batch_index)

        if args.visualize:
            writer.add_scalars('data/time', {'time_10_it': time_arr.sum()},
                               global_step=iteration)

        if iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            _save_adnet_checkpoint(
                os.path.join(args.save_folder, args.save_file) +
                repr(iteration) + '_epoch' + repr(epoch) + '.pth', epoch, net,
                domain_specific_nets, optimizer)

    # final save
    _save_adnet_checkpoint(
        os.path.join(args.save_folder, args.save_file) + '.pth', epoch, net,
        domain_specific_nets, optimizer)

    return net, domain_specific_nets, train_videos
Example #9
0
def main():
    """Fine-tune ADNet as a PPO actor-critic on ILSVRC tracking clips.

    Configuration comes from ``get_args()`` and the module-level ``opts``.
    The gym environment (``args.env_name``) is fed the ILSVRC video infos and
    steps through clips frame by frame; after every clip the PPO agent is
    updated from the collected rollout. Checkpoints
    ``ADNet_RL_epoch<E>_<clip>.pth`` are written every ``args.save_interval``
    clips, plus ``ADNet_RL_final.pth`` at the end, into ``args.save_dir``.
    """
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = args.log_dir
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    save_path = os.path.join(args.save_dir)
    try:
        os.makedirs(save_path)
    except OSError:
        # Best effort: the directory may already exist, or save_dir may be ""
        # (os.makedirs("") raises), in which case files go to the cwd.
        pass

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    env = gym.make(args.env_name).unwrapped

    # Fixed selection for get_ILSVRC_eval_infos: all images (eval_imgs=0) of
    # train.txt (dataset_year=2222), sampling every frame (gt_skip=1).
    dataset_args = argparse.Namespace(eval_imgs=0, gt_skip=1,
                                      dataset_year=2222)
    videos_infos, _ = get_ILSVRC_eval_infos(dataset_args)

    mean = np.array(opts['means'], dtype=np.float32)
    mean = torch.from_numpy(mean).to(device)
    transform = ADNet_Augmentation2(opts, mean)

    env.init_data(videos_infos, opts, transform, do_action, overlap_ratio)

    net, _ = adnet(opts, trained_file=args.resume,
                   random_initialize_domain_specific=True,
                   multidomain=False)

    actor_critic = Policy(
        env.observation_space.shape,
        env.action_space,
        base=net,
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(args.num_steps,
                              opts['inputSize_transpose'], env.action_space,
                              )

    epoch = 0  # keeps the final save well-defined when num_epoch == 0
    for epoch in range(args.num_epoch):
        env.reset_env()
        rollouts.reset_storage()
        obs = env.reset()
        rollouts.obs[0].copy_(obs)
        rollouts.to(device)
        j = -1

        # Running sums of the critic's predicted values: (va, n_va) since the
        # last 100-sample report, (va_epoch, n_va_epoch) for the whole epoch.
        va = 0
        n_va = 0
        va_epoch = 0
        n_va_epoch = 0
        while True:
            j += 1  # index of the current clip within this epoch
            actor_critic.base.reset_action_dynamic()

            if args.use_linear_lr_decay:
                # decrease learning rate linearly over the clips of an epoch
                utils.update_linear_schedule(
                    agent.optimizer, j, len(videos_infos),
                    agent.optimizer.lr if args.algo == "acktr" else args.lr)

            box_history_clip = []
            t = 0  # actions taken on the current frame
            # NOTE(review): ``step`` is never advanced, so act() below always
            # reads slot 0 of the rollout while get_value() uses
            # rollouts.get_step(); confirm this is intended.
            step = 0
            while True:  # one clip
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])
                    va += value
                    n_va += 1

                # Observe reward and next obs
                obs, new_state, reward, done, infos = env.step(action)
                reward = torch.Tensor([reward])

                masks = torch.FloatTensor(
                    [1.0])
                bad_masks = torch.FloatTensor(
                    [1.0])
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_prob, value, reward, masks, bad_masks)

                # Force a stop (and advance exactly one frame) when the box
                # revisits an earlier position on this frame, or when the
                # per-frame action budget is exhausted. ``elif`` prevents both
                # triggers from advancing two frames in a single step.
                if ((action != opts['stop_action']) and any(
                        (np.array(new_state).round() == x).all() for x in np.array(box_history_clip).round())):
                    action = opts['stop_action']
                    reward, done, finish_epoch = env.go_to_next_frame()
                    infos['finish_epoch'] = finish_epoch
                elif t > opts['num_action_step_max']:
                    # todo: in this situation, reward/feedback should be punished.
                    action = opts['stop_action']
                    reward, done, finish_epoch = env.go_to_next_frame()
                    infos['finish_epoch'] = finish_epoch

                box_history_clip.append(list(new_state))

                t += 1

                if action == opts['stop_action']:  # finish one frame
                    t = 0
                    box_history_clip = []
                    rollouts.obs[rollouts.get_step()].copy_(env.get_current_patch())

                if done:  # if finish the clip
                    rollouts.obs[rollouts.get_step()].copy_(env.get_current_patch())
                    break

            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.obs[rollouts.get_step()], rollouts.recurrent_hidden_states[rollouts.get_step()],
                    rollouts.masks[rollouts.get_step()]).detach()

            rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                     args.gae_lambda, args.use_proper_time_limits)

            value_loss, action_loss, dist_entropy = agent.update(rollouts)
            rollouts.obs[rollouts.get_step()].copy_(env.get_current_patch())

            rollouts.after_update()

            if n_va >= 100:
                ave_va = va / n_va
                print("current clip: %d, n_va: %d, cur v: %.2f, cur ave v: %.2f" % (j, n_va, value, ave_va))
                va_epoch += va
                n_va_epoch += n_va
                va = 0
                n_va = 0

            if infos['finish_epoch']:
                # Fold in the values sampled since the last report; otherwise
                # they are lost and a short epoch would divide by zero.
                va_epoch += va
                n_va_epoch += n_va
                va = 0
                n_va = 0
                ave_va_epoch = va_epoch / n_va_epoch
                print("epoch: %d, ave value of v: %.2f" % (epoch, ave_va_epoch))
                va_epoch = 0
                n_va_epoch = 0
                break

            # save every save_interval-th clip
            if (j % args.save_interval == 0) and args.save_dir != "":
                torch.save({
                    'epoch': epoch,
                    'adnet_state_dict': actor_critic.base.state_dict(),
                }, os.path.join(save_path, 'ADNet_RL_epoch' + repr(epoch) + "_" + repr(j) + '.pth'))

    torch.save({
        'epoch': epoch,
        'adnet_state_dict': actor_critic.base.state_dict(),
    }, os.path.join(save_path, 'ADNet_RL_final.pth'))