Exemple #1
0
def get_activations(PATH, batch):
    """Run one batch through a pretrained DPC-RNN and collect Conv3d outputs.

    Args:
        PATH: path to a saved pretrained model checkpoint; the checkpoint must
            contain a ``'state_dict'`` entry.
        batch: numpy array (or tensor) of images/stimuli of size
            (batch size X number of blocks X colors X number of frames X height X width).

    Returns:
        A dictionary mapping each ``nn.Conv3d`` module name to its output
        activations (detached, on CPU). Outputs from several calls of the same
        layer during one forward pass are concatenated along dim 0.
    """
    import collections
    from functools import partial

    model = DPC_RNN(sample_size=48,
                    num_seq=8,
                    seq_len=5,
                    network='resnet18',
                    pred_step=3)

    # Load onto CPU so GPU-saved checkpoints also work on CPU-only machines.
    checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
    model = neq_load_customized(model, checkpoint['state_dict'])
    # Feature extraction: use inference-mode behaviour for dropout/batch-norm.
    model.eval()

    activations = collections.defaultdict(list)

    def save_activation(name, module, inputs, output):
        # Forward-hook signature is (module, input, output); `name` is bound
        # in front of it with functools.partial.
        activations[name].append(output.detach().cpu())

    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv3d):
            print(name)
            # partial to assign the layer name to each hook
            handles.append(
                module.register_forward_hook(partial(save_activation, name)))

    inputs = torch.as_tensor(batch, dtype=torch.float32)
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always detach the hooks so the model is left unmodified.
        for handle in handles:
            handle.remove()

    return {name: torch.cat(outputs, 0) for name, outputs in activations.items()}
Exemple #2
0
def load_model(model, model_path):
    """Load checkpoint weights into ``model`` when ``model_path`` is a file.

    The checkpoint is read onto the CPU and its ``'state_dict'`` is applied
    with ``neq_load_customized``. If no file exists at ``model_path``, a
    warning is printed and ``model`` is returned untouched.

    Returns:
        The model returned by ``neq_load_customized``, or the input model when
        there is no checkpoint.
    """
    if not os.path.isfile(model_path):
        print("[WARNING] no checkpoint found at '{}'".format(model_path))
        return model

    print("=> loading resumed checkpoint '{}'".format(model_path))
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    loaded = neq_load_customized(model, ckpt['state_dict'])
    epoch_done = ckpt['epoch']
    print("=> loaded resumed checkpoint '{}' (epoch {})".format(
        model_path, epoch_done))
    return loaded
Exemple #3
0
def main():
    """Train a DPC-RNN model with options parsed from the module ``parser``.

    Seeds torch and numpy and builds the model (``nn.DataParallel`` on CUDA).
    It can resume from ``args.resume`` and/or initialise weights from
    ``args.pretrain``. It then runs ``train``/``validate`` for every epoch in
    ``[args.start_epoch, args.epochs)``, logging scalars to TensorBoard and
    saving a checkpoint after each epoch.

    Sets the module globals ``args``, ``cuda``, ``criterion``, ``iteration``,
    ``de_normalize``, ``img_path`` and ``writer_train``. These are presumably
    read by the ``train``/``validate`` helpers; verify against them.

    Raises:
        ValueError: if ``args.model`` or ``args.dataset`` is not supported.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    global args
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    global cuda
    cuda = torch.device('cuda')

    ### dpc model ###
    if args.model == 'dpc-rnn':
        model = DPC_RNN(sample_size=args.img_dim,
                        num_seq=args.num_seq,
                        seq_len=args.seq_len,
                        network=args.net,
                        pred_step=args.pred_step)
    else:
        raise ValueError('wrong model!')

    model = nn.DataParallel(model)
    model = model.to(cuda)
    global criterion
    criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    if args.train_what == 'last':
        # Freeze the backbone; only the remaining layers are trained.
        for name, param in model.module.resnet.named_parameters():
            param.requires_grad = False
    else:
        pass  # train all layers

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    args.old_lr = None

    best_acc = 0
    global iteration
    iteration = 0

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            # The previous lr is encoded in the checkpoint path as '_lr<value>_'.
            # A path without it must not crash resuming; the value is only used
            # for the log message below.
            lr_match = re.search('_lr(.+?)_', args.resume)
            args.old_lr = float(lr_match.group(1)) if lr_match else None
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            if not args.reset_lr:  # if didn't reset lr, load old optimizer
                optimizer.load_state_dict(checkpoint['optimizer'])
            elif args.old_lr is not None:
                print('==== Change lr from %f to %f ====' %
                      (args.old_lr, args.lr))
            else:
                print('==== Change lr to %f ====' % args.lr)
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load data ###
    if args.dataset == 'ucf101':  # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            RandomCrop(size=224, consistent=True),
            Scale(size=(args.img_dim, args.img_dim)),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset == 'k400':  # designed for kinetics400, short size=150, rand crop to 128x128
        transform = transforms.Compose([
            RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0),
            RandomHorizontalFlip(consistent=True),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    else:
        # Without this, `transform` would be unbound below (UnboundLocalError).
        raise ValueError('wrong dataset!')

    train_loader = get_data(transform, 'train')
    val_loader = get_data(transform, 'val')

    # setup tools
    global de_normalize
    de_normalize = denorm()
    global img_path
    img_path, model_path = set_path(args)
    global writer_train
    try:  # old version
        writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except TypeError:  # v1.7 renamed the keyword to `logdir`
        writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_acc, train_accuracy_list = train(
            train_loader, model, optimizer, epoch)
        val_loss, val_acc, val_accuracy_list = validate(
            val_loader, model, epoch)

        # save curve
        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/accuracy', train_acc, epoch)
        writer_val.add_scalar('global/loss', val_loss, epoch)
        writer_val.add_scalar('global/accuracy', val_acc, epoch)
        writer_train.add_scalar('accuracy/top1', train_accuracy_list[0], epoch)
        writer_train.add_scalar('accuracy/top3', train_accuracy_list[1], epoch)
        writer_train.add_scalar('accuracy/top5', train_accuracy_list[2], epoch)
        writer_val.add_scalar('accuracy/top1', val_accuracy_list[0], epoch)
        writer_val.add_scalar('accuracy/top3', val_accuracy_list[1], epoch)
        writer_val.add_scalar('accuracy/top5', val_accuracy_list[2], epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'net': args.net,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'iteration': iteration
            },
            is_best,
            filename=os.path.join(model_path,
                                  'epoch%s.pth.tar' % str(epoch + 1)),
            keep_all=False)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
Exemple #4
0
def main():
    """Train or evaluate an ``LC`` classifier with options from ``parser``.

    The action depends on the parsed flags:
      * ``args.ensemble``: combine the probability pickles named in
        ``args.prob_*`` via ``ensemble`` and exit.
      * ``args.test``: load that checkpoint (or keep random weights when it is
        ``'random'``), run ``test`` on the test split and exit.
      * otherwise: train. Training can resume from ``args.resume`` or start
        from ``args.pretrain``. The backbone is frozen for the first epoch and
        unfrozen afterwards. ``test`` runs periodically, and a checkpoint is
        saved after every epoch.

    Raises:
        ValueError: for an unsupported ``args.model`` or a ``args.test`` path
            that is neither a file nor ``'random'``.
    """
    global args
    args = parser.parse_args()
    global cuda
    cuda = torch.device('cuda')

    if args.dataset == 'ucf101':
        args.num_class = 101
    elif args.dataset == 'hmdb51':
        args.num_class = 51

    if args.ensemble:

        def read_pkl(fname):
            # An empty filename means this modality is not part of the ensemble.
            if fname == '':
                return None
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(fname, 'rb') as f:
                return pickle.load(f)

        ensemble(read_pkl(args.prob_imgs), read_pkl(args.prob_flow),
                 read_pkl(args.prob_seg), read_pkl(args.prob_kphm))
        sys.exit()

    args.in_channels = get_num_channels(args.modality)

    ### classifier model ###
    if args.model == 'lc':
        model = LC(sample_size=args.img_dim,
                   num_seq=args.num_seq,
                   seq_len=args.seq_len,
                   in_channels=args.in_channels,
                   network=args.net,
                   num_class=args.num_class,
                   dropout=args.dropout)
    else:
        raise ValueError('wrong model!')

    model = nn.DataParallel(model)
    model = model.to(cuda)
    global criterion
    criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    params = None
    if args.train_what == 'ft':
        print('=> finetune backbone with smaller lr')
        params = []
        for name, param in model.module.named_parameters():
            if ('resnet' in name) or ('rnn' in name):
                params.append({'params': param, 'lr': args.lr / 10})
            else:
                params.append({'params': param})
    elif args.train_what == 'freeze':
        print('=> Freeze backbone')
        # Freeze only the backbone (same name filter as 'ft') and optimise the
        # rest. Freezing every parameter, as before, handed Adam an empty
        # parameter list, which raises ValueError.
        params = []
        for name, param in model.module.named_parameters():
            if ('resnet' in name) or ('rnn' in name):
                param.requires_grad = False
            else:
                params.append(param)
    else:
        pass  # train all layers

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(name, param.requires_grad)
    print('=================================\n')

    if params is None:
        params = model.parameters()

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    # Step-decay milestones: later milestones for full-resolution (224) input.
    if args.img_dim == 224:
        lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(
            ep, gamma=0.1, step=[60, 120, 180], repeat=1)
    else:
        lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(
            ep, gamma=0.1, step=[50, 70, 90], repeat=1)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    args.old_lr = None
    best_acc = 0
    global iteration
    iteration = 0
    global num_epoch
    num_epoch = 0

    ### test only ###
    if args.test:
        if os.path.isfile(args.test):
            print("=> loading testing checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test)
            try:
                model.load_state_dict(checkpoint['state_dict'])
            except RuntimeError:  # raised on missing/unexpected/mismatched keys
                print(
                    '=> [Warning]: weight structure is not equal to test model; Use non-equal load =='
                )
                model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded testing checkpoint '{}' (epoch {})".format(
                args.test, checkpoint['epoch']))
            num_epoch = checkpoint['epoch']
        elif args.test == 'random':
            print("=> [Warning] loaded random weights")
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(args.test))

        test_loader = get_data_loader(args, 'test')
        test_loss, test_acc = test(test_loader,
                                   model,
                                   extensive=args.extensive)
        sys.exit()
    else:  # not test
        torch.backends.cudnn.benchmark = True

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            # The previous lr is not parsed from the path here; assume 1e-3
            # (used only for the log message below).
            args.old_lr = 1e-3
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            if not args.reset_lr:  # if didn't reset lr, load old optimizer
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('==== Change lr from %f to %f ====' %
                      (args.old_lr, args.lr))
            iteration = checkpoint['iteration']
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Pretrained weights are only used when not resuming.
    if (not args.resume) and args.pretrain:
        if args.pretrain == 'random':
            print('=> using random weights')
        elif os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load data ###
    train_loader = get_data_loader(args, 'train')
    val_loader = get_data_loader(args, 'val')
    test_loader = get_data_loader(args, 'test')

    # setup tools
    global de_normalize
    de_normalize = denorm()
    global img_path
    img_path, model_path = set_path(args)
    global writer_train
    try:  # old version
        writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except TypeError:  # v1.7 renamed the keyword to `logdir`
        writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    args.test = model_path
    print("Model path:", model_path)

    # Freeze the model backbone initially
    model = freeze_backbone(model)
    # Epochs since the last test triggered by a new best validation accuracy.
    cooldown = 0

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        num_epoch = epoch

        train_loss, train_acc = train(train_loader, model, optimizer, epoch)
        val_loss, val_acc = validate(val_loader, model)
        scheduler.step(epoch)

        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/accuracy', train_acc, epoch)
        writer_val.add_scalar('global/loss', val_loss, epoch)
        writer_val.add_scalar('global/accuracy', val_acc, epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        # Perform testing if either the frequency is hit or the model is the best after a few epochs
        if (epoch + 1) % args.full_eval_freq == 0:
            test(test_loader, model)
        elif (epoch > 70) and (cooldown >= 5) and is_best:
            test(test_loader, model)
            cooldown = 0
        else:
            cooldown += 1

        save_checkpoint(state={
            'epoch': epoch + 1,
            'net': args.net,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'iteration': iteration
        },
                        mode=args.modality,
                        is_best=is_best,
                        gap=5,
                        filename=os.path.join(
                            model_path, 'epoch%s.pth.tar' % str(epoch + 1)),
                        keep_all=False)

        # Unfreeze the model backbone after the first run
        if epoch == args.start_epoch:
            model = unfreeze_backbone(model)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    print("Model path:", model_path)
Exemple #5
0
def main():
    """Run a frozen feature/prediction model over one split to get prediction errors.

    Parses ``parser`` into the global ``args`` and picks the output folder
    ``pkl_folder`` (global), creating it if needed. It then builds the model
    selected by ``args.model``, disables gradients on all parameters and
    optionally loads ``args.pretrain``. Finally it calls ``validate`` on the
    ``args.mode`` split. ``criterion`` is set to an element-wise MSE
    (``reduction='none'``).

    Raises:
        ValueError: if ``args.model`` is unsupported, or if no
            ``pkl_folder_name`` is given and ``args.dataset`` has no default
            output folder.
    """
    torch.manual_seed(0)
    np.random.seed(0)
    global args
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    global cuda
    cuda = torch.device('cuda')

    ### output file directory ###
    pred_err_suffix = '/pred_err'
    global pkl_folder
    if args.pkl_folder_name != '':  # allow user to explicitly specify output dir; otherwise follow default in the below
        pkl_folder = args.pkl_folder_name + pred_err_suffix
    else:
        if args.dataset == 'tapos_instances':
            pkl_folder = '../../data/exp_TAPOS/' + args.pred_task + pred_err_suffix
        elif args.dataset == 'k400':
            pkl_folder = '../../data/exp_k400/' + args.pred_task + pred_err_suffix
        else:
            # Previously fell through and crashed later with a NameError.
            raise ValueError(
                "no default output folder for dataset '{}'; set pkl_folder_name"
                .format(args.dataset))
        # Record the experiment root, i.e. the folder without '/pred_err'.
        args.pkl_folder_name = pkl_folder[:-len(pred_err_suffix)]

    os.makedirs(pkl_folder, exist_ok=True)

    ### dpc model ###
    if args.model == 'dpc-rnn':
        model = DPC_RNN_Infer_Pred_Error(sample_size=args.img_dim,
                                         num_seq=args.num_seq,
                                         seq_len=args.seq_len,
                                         network=args.net,
                                         pred_step=args.pred_step)
    elif args.model == 'resnet50-beforepool':
        model = ResNet50_BeforePool_Feature_Extractor(pretrained=True,
                                                      num_seq=args.num_seq,
                                                      seq_len=args.seq_len,
                                                      pred_step=args.pred_step)
    elif args.model == 'resnet18-beforepool':
        model = ResNet18_BeforePool_Feature_Extractor(pretrained=True,
                                                      num_seq=args.num_seq,
                                                      seq_len=args.seq_len,
                                                      pred_step=args.pred_step)
    elif args.model == 'rgb-avg-temporal':
        model = RBG_Extractor(pretrained=True,
                              num_seq=args.num_seq,
                              seq_len=args.seq_len,
                              pred_step=args.pred_step)
    else:
        raise ValueError('wrong model!')

    model = nn.DataParallel(model)
    model = model.to(cuda)
    global criterion
    criterion = nn.MSELoss(reduction='none')

    # Inference only: no parameter is trained.
    print('\n===========No grad for all layers============')
    for name, param in model.named_parameters():
        param.requires_grad = False
        print(name, param.requires_grad)
    print('==============================================\n')

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load data ###
    transform = transforms.Compose(
        [Scale(size=(args.img_dim, args.img_dim)),
         ToTensor(),
         Normalize()])

    val_loader = get_data(transform, args.mode)

    ### main loop ###
    validate(val_loader, model)
Exemple #6
0
def main():
    """Train or evaluate the ``Audio_RNN`` model with options from ``parser``.

    Test mode (``args.test`` set):
      * load the checkpoint (or keep random weights when it is ``'random'``);
      * run ``test`` on the test split and ``get_threshold`` on the train
        split;
      * pickle the global score/target/chunk-count dicts those helpers fill
        into the working directory, then exit.

    Training mode:
      * optionally resume from ``args.resume``;
      * train for epochs ``[args.start_epoch, args.epochs)``, saving a
        checkpoint after every epoch (``keep_all=True``);
      * print the per-epoch losses.

    Raises:
        ValueError: if ``args.test`` is neither an existing file nor ``'random'``.
    """
    # Module-level state shared with the train/test helpers.
    global args, cuda, criterion, iteration, num_epoch
    global test_dissimilarity_score, test_target, test_number_ofile_number_of_chunks
    global train_dissimilarity_score, train_target, train_number_ofile_number_of_chunks
    global total_sum_scores_real, total_sum_scores_fake, count_real, count_fake
    global de_normalize, img_path, writer_train

    torch.manual_seed(0)
    np.random.seed(0)
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    cuda = torch.device('cuda')

    model = Audio_RNN(img_dim=args.img_dim, network=args.net, num_layers_in_fc_layers=args.final_dim,
                      dropout=args.dropout)

    model = nn.DataParallel(model)
    model = model.to(cuda)
    criterion = nn.CrossEntropyLoss()

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    least_loss = 0
    iteration = 0

    args.old_lr = None

    # The same deterministic preprocessing is used for every split.
    transform = transforms.Compose([
        Scale(size=(args.img_dim, args.img_dim)),
        ToTensor(),
        Normalize()
    ])

    def dump_pickles(targets):
        # Write each (object, filename) pair; `with` closes files even on error.
        for obj, fname in targets:
            with open(fname, "wb") as f:
                pickle.dump(obj, f)

    if args.test:
        if os.path.isfile(args.test):
            print("=> loading testing checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test)
            try:
                model.load_state_dict(checkpoint['state_dict'])
            except RuntimeError:  # raised on missing/unexpected/mismatched keys
                # A mismatched checkpoint cannot be evaluated; exit non-zero
                # instead of reporting success.
                print('=> [Error]: weight structure is not equal to test model; aborting ==')
                sys.exit(1)
            print("=> loaded testing checkpoint '{}' (epoch {})".format(args.test, checkpoint['epoch']))
            num_epoch = checkpoint['epoch']
        elif args.test == 'random':
            print("=> [Warning] loaded random weights")
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(args.test))

        test_loader = get_data(transform, 'test', 1)
        test_dissimilarity_score = {}
        test_target = {}
        test_number_ofile_number_of_chunks = {}
        test(test_loader, model)
        dump_pickles([
            (test_dissimilarity_score, "file_dissimilarity_score.pkl"),
            (test_target, "file_target.pkl"),
            (test_number_ofile_number_of_chunks, "file_number_of_chunks.pkl"),
        ])

        print('get train threshold!!!!')
        train_loader = get_data(transform, 'train', 0)
        train_dissimilarity_score = {}
        train_target = {}
        train_number_ofile_number_of_chunks = {}

        get_threshold(train_loader, model)

        dump_pickles([
            (train_dissimilarity_score, "train_dissimilarity_score.pkl"),
            (train_target, "train_target.pkl"),
            (train_number_ofile_number_of_chunks, "train_number_ofile_number_of_chunks.pkl"),
        ])

        sys.exit()
    else:  # not test
        torch.backends.cudnn.benchmark = True

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            least_loss = checkpoint['least_loss']
            try:
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
            except (RuntimeError, ValueError, KeyError):
                # Model/optimizer structure differs from the checkpoint:
                # fall back to loading whatever weights match.
                print('=> [Warning]: weight structure is not equal to checkpoint; Use non-equal load ==')
                model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    train_loader = get_data(transform, 'train', 1)

    total_sum_scores_real = 0
    total_sum_scores_fake = 0
    count_real = 0
    count_fake = 0
    # setup tools
    de_normalize = denorm()
    img_path, model_path = set_path(args)
    try:  # old version
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except TypeError:  # v1.7 renamed the keyword to `logdir`
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    ### main loop ###
    epoch_losses = np.zeros(args.epochs)
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, avg_score_real, avg_score_fake = train(train_loader, model, optimizer, epoch)
        epoch_losses[epoch] = train_loss
        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/avg_score_fake', avg_score_fake, epoch)
        writer_train.add_scalar('global/avg_score_real', avg_score_real, epoch)

        # save check_point; the first epoch seeds the running minimum loss
        if epoch == 0:
            least_loss = train_loss
        is_best = train_loss <= least_loss
        least_loss = min(least_loss, train_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'net': args.net,
            'state_dict': model.state_dict(),
            'least_loss': least_loss,
            'avg_score_real': avg_score_real,
            'avg_score_fake': avg_score_fake,
            'optimizer': optimizer.state_dict(),
            'iteration': iteration
        }, is_best, filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch + 1)), keep_all=True)

    print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs))
    # `range(epoch_losses)` raised TypeError; report only epochs trained in this run.
    for i in range(args.start_epoch, args.epochs):
        print('Epoch:', i, 'Loss:', epoch_losses[i])
Exemple #7
0
def main():

    # set to constant for consistant results
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    global args
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    global cuda
    cuda = torch.device('cuda')

    ### dpc model ###
    if args.model == 'dpc-rnn':
        model = DPC_RNN(sample_size=args.img_dim,
                        num_seq=args.num_seq,
                        seq_len=args.seq_len,
                        network=args.net,
                        pred_step=args.pred_step,
                        distance=args.distance,
                        distance_type=args.distance_type,
                        weighting=args.weighting,
                        margin=args.margin,
                        pool=args.pool,
                        loss_type=args.loss_type)
    else:
        raise ValueError('wrong model!')

    # parallelize the model
    model = nn.DataParallel(model)
    model = model.to(cuda)

    # load dict [frame_id -> action]
    if args.action_from_frame == 'True':
        with open('/proj/vondrick/ruoshi/github/TPC/tpc/action_from_frame.p',
                  'rb') as fp:
            map_action_frame = pickle.load(fp)

    ### optimizer ###
    # dont' think we need to use 'last' keyword during pre-training anywhere
    if args.train_what == 'last':
        for name, param in model.module.resnet.named_parameters():
            param.requires_grad = False
    else:
        pass  # train all layers

    # check if gradient flowing to appropriate layers
    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    args.old_lr = None

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5)

    best_acc = 0
    global iteration
    iteration = 0

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            # load the old model and set the learning rate accordingly

            # get the old learning rate
            args.old_lr = float(re.search('_lr(.+?)_', args.resume).group(1))
            print("=> loading resumed checkpoint '{}'".format(args.resume))

            #
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])

            # load old optimizer, start with the corresponding learning rate
            if not args.reset_lr:  # if didn't reset lr, load old optimizer
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                # reset to new learning rate
                print('==== Change lr from %f to %f ====' %
                      (args.old_lr, args.lr))
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))

            # neq_load_customized
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load transform & data ###
    if 'epic' in args.dataset:
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            # RandomSizedCrop(size=224, consistent=True, p=1.0),
            RandomCrop(size=224, consistent=True),
            Scale(size=(args.img_dim, args.img_dim)),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif 'ucf' in args.dataset:  # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            RandomCrop(size=224, consistent=True),
            Scale(size=(args.img_dim, args.img_dim)),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset == 'k400':  # designed for kinetics400, short size=150, rand crop to 128x128
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset in ['block_toy_1', 'block_toy_2', 'block_toy_3']:
        # get validation data with label for plotting the embedding
        transform = transforms.Compose([
            # centercrop
            CenterCrop(size=224),
            # RandomSizedCrop(consistent=True, size=224, p=0.0), # no effect
            Scale(size=(args.img_dim, args.img_dim)),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset in [
            'block_toy_imagenet_1', 'block_toy_imagenet_2',
            'block_toy_imagenet_3'
    ]:
        # may have to introduce more transformations with imagenet background
        transform = transforms.Compose([
            # centercrop
            CenterCrop(size=224),
            # RandomSizedCrop(consistent=True, size=224, p=0.0), # no effect
            Scale(size=(args.img_dim, args.img_dim)),
            ToTensor(),
            Normalize()
        ])

    train_loader = get_data(transform, mode='train')
    val_loader = get_data(transform, mode='val')

    # setup tools

    # denormalize to display input images via tensorboard
    global de_normalize
    de_normalize = denorm()

    global img_path
    img_path, model_path = set_path(args)
    global writer_train

    # book-keeping
    try:  # old version
        writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except:  # v1.7
        writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_acc, train_accuracy_list, radius, radius_var = train(
            train_loader, model, optimizer, epoch)
        val_loss, val_acc, val_accuracy_list, radius, radius_var = validate(
            val_loader, model, epoch)

        #         scheduler.step(val_loss)

        # save curve
        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/accuracy', train_acc, epoch)
        writer_train.add_scalar('global/radius', radius, epoch)
        writer_train.add_scalar('global/radius_var', radius_var, epoch)

        writer_val.add_scalar('global/loss', val_loss, epoch)
        writer_val.add_scalar('global/accuracy', val_acc, epoch)
        writer_val.add_scalar('global/radius', radius, epoch)
        writer_val.add_scalar('global/radius_var', radius_var, epoch)

        writer_train.add_scalar('accuracy/top1', train_accuracy_list[0], epoch)
        writer_train.add_scalar('accuracy/top3', train_accuracy_list[1], epoch)
        writer_train.add_scalar('accuracy/top5', train_accuracy_list[2], epoch)
        writer_val.add_scalar('accuracy/top1', val_accuracy_list[0], epoch)
        writer_val.add_scalar('accuracy/top3', val_accuracy_list[1], epoch)
        writer_val.add_scalar('accuracy/top5', val_accuracy_list[2], epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'net': args.net,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'iteration': iteration
            },
            is_best,
            filename=os.path.join(model_path,
                                  'epoch%s.pth.tar' % str(epoch + 1)),
            save_every=10)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
Exemple #8
0
def main():
    """Fine-tune or evaluate the ``LC`` action classifier.

    Parses the command line with the module-level ``parser`` and builds an
    ``LC`` model wrapped in ``nn.DataParallel`` on the CUDA device. Then:

    * if ``--test`` is given, loads that checkpoint (or keeps random weights
      when it is ``'random'``), evaluates on the ``'test'`` split and exits
      via ``sys.exit()``;
    * otherwise optionally resumes from ``--resume`` (model, optimizer, epoch,
      best accuracy, iteration) or, when not resuming, initialises from
      ``--pretrain``; then trains for epochs ``args.start_epoch`` ..
      ``args.epochs - 1``, logging to TensorBoard and saving a checkpoint
      after every epoch.

    Assigns the module-level globals ``args``, ``cuda``, ``criterion``,
    ``iteration``, ``de_normalize``, ``img_path``, ``writer_train`` (and
    ``num_epoch`` in test mode). Presumably these are read by ``train``,
    ``validate`` and ``test`` defined elsewhere; confirm before renaming.

    Raises:
        ValueError: unknown ``--model``, a dataset with no lr schedule, or a
            ``--test`` value that is neither an existing file nor ``'random'``.
    """
    global args, cuda, criterion, iteration, num_epoch
    global de_normalize, img_path, writer_train

    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    cuda = torch.device('cuda')

    if args.dataset == 'ucf101':
        args.num_class = 101
    elif args.dataset == 'hmdb51':
        args.num_class = 51

    ### classifier model ###
    if args.model == 'lc':
        model = LC(sample_size=args.img_dim,
                   num_seq=args.num_seq,
                   seq_len=args.seq_len,
                   network=args.net,
                   num_class=args.num_class,
                   dropout=args.dropout)
    else:
        raise ValueError('wrong model!')

    model = nn.DataParallel(model)
    model = model.to(cuda)
    criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    params = None
    if args.train_what == 'ft':
        print('=> finetune backbone with smaller lr')
        # Parameters whose names contain 'resnet' or 'rnn' get lr/10; all
        # others use the base lr passed to Adam below.
        params = []
        for name, param in model.module.named_parameters():
            if ('resnet' in name) or ('rnn' in name):
                params.append({'params': param, 'lr': args.lr / 10})
            else:
                params.append({'params': param})
    # any other train_what: train all layers with the base lr

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    if params is None:
        params = model.parameters()

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    # Epoch milestones for the lr multiplier depend on dataset / resolution.
    if args.dataset == 'hmdb51':
        milestones = [150, 250, 300]
    elif args.dataset == 'ucf101':
        milestones = [300, 400, 500] if args.img_dim == 224 else [60, 80, 100]
    else:
        # Without this, lr_lambda would be unbound and LambdaLR would crash.
        raise ValueError("no lr schedule defined for dataset '%s'" % args.dataset)

    def lr_lambda(ep):
        # multiplicative lr factor for epoch ``ep``; presumably decays by
        # gamma at each milestone (see MultiStepLR_Restart_Multiplier)
        return MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=milestones, repeat=1)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    args.old_lr = None
    best_acc = 0
    iteration = 0

    ### evaluation only ###
    if args.test:
        if os.path.isfile(args.test):
            print("=> loading testing checkpoint '{}'".format(args.test))
            # Map to CPU like the other loads; load_state_dict copies the
            # values into the existing (CUDA) parameters.
            checkpoint = torch.load(args.test, map_location=torch.device('cpu'))
            try:
                model.load_state_dict(checkpoint['state_dict'])
            except RuntimeError:
                # strict load failed on missing/unexpected keys or shape
                # mismatch; fall back to the project's non-equal loader
                print('=> [Warning]: weight structure is not equal to test model; Use non-equal load ==')
                model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded testing checkpoint '{}' (epoch {})".format(args.test, checkpoint['epoch']))
            num_epoch = checkpoint['epoch']
        elif args.test == 'random':
            print("=> [Warning] loaded random weights")
        else:
            raise ValueError("--test must be an existing checkpoint file or 'random', got '{}'".format(args.test))

        transform = transforms.Compose([
            RandomSizedCrop(consistent=True, size=224, p=0.0),
            Scale(size=(args.img_dim, args.img_dim)),
            ToTensor(),
            Normalize()
        ])
        test_loader = get_data(transform, 'test')
        test_loss, test_acc = test(test_loader, model)
        sys.exit()

    # training only from here on
    torch.backends.cudnn.benchmark = True

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            # The old lr is parsed from the run path ("..._lr<value>_..."); it is
            # used here only for the log line below, so tolerate its absence.
            lr_match = re.search('_lr(.+?)_', args.resume)
            args.old_lr = float(lr_match.group(1)) if lr_match else None
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            if not args.reset_lr:  # keep the saved optimizer state (incl. its lr)
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                old_lr = '%f' % args.old_lr if args.old_lr is not None else 'unknown'
                print('==== Change lr from %s to %f ====' % (old_lr, args.lr))
            iteration = checkpoint['iteration']
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Pretrained weights are only used when not resuming a run.
    if (not args.resume) and args.pretrain:
        if args.pretrain == 'random':
            print('=> using random weights')
        elif os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(args.pretrain))
            checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu'))
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load data ###
    transform = transforms.Compose([
        RandomSizedCrop(consistent=True, size=224, p=1.0),
        Scale(size=(args.img_dim, args.img_dim)),
        RandomHorizontalFlip(consistent=True),
        ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True),
        ToTensor(),
        Normalize()
    ])
    # milder augmentation for validation
    val_transform = transforms.Compose([
        RandomSizedCrop(consistent=True, size=224, p=0.3),
        Scale(size=(args.img_dim, args.img_dim)),
        RandomHorizontalFlip(consistent=True),
        ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True),
        ToTensor(),
        Normalize()
    ])

    train_loader = get_data(transform, 'train')
    val_loader = get_data(val_transform, 'val')

    # setup tools
    de_normalize = denorm()
    img_path, model_path = set_path(args)
    try:  # old version
        writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except TypeError:  # v1.7: keyword renamed to ``logdir``
        writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_acc = train(train_loader, model, optimizer, epoch)
        val_loss, val_acc = validate(val_loader, model)
        # explicit epoch keeps the schedule aligned after a resume
        scheduler.step(epoch)

        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/accuracy', train_acc, epoch)
        writer_val.add_scalar('global/loss', val_loss, epoch)
        writer_val.add_scalar('global/accuracy', val_acc, epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'net': args.net,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'iteration': iteration
        }, is_best, filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch + 1)), keep_all=False)

    # flush pending TensorBoard events to disk
    writer_train.close()
    writer_val.close()
    print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs))
Exemple #9
0
def main():
    """Self-supervised DPC pre-training, distributed over GPUs with DataParallel.

    Seeds torch and numpy with 0 and parses the command line with the
    module-level ``parser``. It exposes the GPUs listed in ``args.gpu``
    (nvidia-smi / PCI-bus ids) and builds a ``DPC_RNN`` wrapped in
    ``nn.DataParallel``.

    If ``--resume`` points to an existing checkpoint, it restores model,
    optimizer, epoch, iteration and best accuracy. Otherwise it loads
    ``--pretrain`` weights if given. It then trains for epochs
    ``args.start_epoch`` .. ``args.epochs - 1``, logging loss/accuracy to
    TensorBoard and saving a checkpoint every epoch.

    Assigns module-level globals ``args``, ``cuda``, ``criterion``,
    ``iteration``, ``de_normalize``, ``img_path`` and ``writer_train``.
    Presumably these are read by ``train``/``validate`` defined elsewhere;
    confirm before renaming.

    Raises:
        ValueError: unknown ``--model`` or a dataset with no transform defined.
    """
    global args, cuda, criterion, iteration
    global de_normalize, img_path, writer_train

    torch.manual_seed(0)
    np.random.seed(0)
    args = parser.parse_args()
    # nvidia-smi lists GPUs in PCI bus order while CUDA defaults to
    # fastest-first; force PCI order so the ids in --gpu match nvidia-smi.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in args.gpu)

    print('Cuda visible devices: {}'.format(os.environ["CUDA_VISIBLE_DEVICES"]))
    print('Available device count: {}'.format(torch.cuda.device_count()))

    # Inside this process the visible GPUs are renumbered 0..N-1, so these
    # (not the nvidia-smi ids) are the ids torch expects from here on.
    args.gpu = list(range(torch.cuda.device_count()))

    print("Note: At least in Pytorch 1.2, device ids are reindexed on the visible devices and not the same as in nvidia-smi.")

    for i in args.gpu:
        print("Using Cuda device {}: {}".format(i, torch.cuda.get_device_name(i)))
    print("Cuda is available: {}".format(torch.cuda.is_available()))
    cuda = torch.device('cuda')

    ### dpc model ###
    if args.model == 'dpc-rnn':
        model = DPC_RNN(sample_size=args.img_dim,
                        num_seq=args.num_seq,
                        seq_len=args.seq_len,
                        network=args.net,
                        pred_step=args.pred_step)
    else:
        raise ValueError('wrong model!')

    # DataParallel keeps the master copy on device 0 and scatters each batch
    # across all visible GPUs, gathering outputs/gradients back.
    model = nn.DataParallel(model)
    model = model.to(cuda)
    # The contrastive objective is scored with cross-entropy over similarity logits.
    criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    if args.train_what == 'last':
        # freeze the backbone; its params get no grad and Adam skips them
        for name, param in model.module.resnet.named_parameters():
            param.requires_grad = False
    # any other train_what: train all layers

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    args.old_lr = None

    best_acc = 0
    iteration = 0

    ### restart training ###
    resumed = False
    if args.resume:
        if os.path.isfile(args.resume):
            # The old lr is parsed from the run path ("..._lr<value>_..."); it is
            # used here only for the log line below, so tolerate its absence.
            lr_match = re.search('_lr(.+?)_', args.resume)
            args.old_lr = float(lr_match.group(1)) if lr_match else None
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            # load_state_dict copies values in place into the existing (CUDA)
            # parameters, so a CPU-mapped checkpoint is fine here.
            model.load_state_dict(checkpoint['state_dict'])
            if not args.reset_lr:  # keep the saved optimizer state (incl. its lr)
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                old_lr = '%f' % args.old_lr if args.old_lr is not None else 'unknown'
                print('==== Change lr from %s to %f ====' % (old_lr, args.lr))
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            resumed = True
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    # Loading pretrained weights after a successful resume would overwrite the
    # resumed weights while keeping the resumed epoch/optimizer state.
    if args.pretrain and resumed:
        print("=> [Warning] ignoring pretrained checkpoint '{}' because training was resumed".format(args.pretrain))
    elif args.pretrain:
        if os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(args.pretrain))
            checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu'))
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})"
                  .format(args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load data ###
    if args.dataset == 'ucf101':  # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            RandomCrop(size=224, consistent=True),
            Scale(size=(args.img_dim, args.img_dim)),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset in ('k400', 'nturgbd'):  # short size=150, rand crop to 128x128
        transform = transforms.Compose([
            RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0),
            RandomHorizontalFlip(consistent=True),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0),
            ToTensor(),
            Normalize()
        ])
    else:
        # Without this, transform would be unbound below.
        raise ValueError("no transform defined for dataset '%s'" % args.dataset)

    train_loader = get_data(transform, 'train')
    val_loader = get_data(transform, 'val')

    # setup tools
    de_normalize = denorm()
    img_path, model_path = set_path(args)
    try:  # old version
        writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except TypeError:  # v1.7: keyword renamed to ``logdir``
        writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_acc, train_accuracy_list = train(train_loader, model, optimizer, epoch)
        val_loss, val_acc, val_accuracy_list = validate(val_loader, model, epoch)

        # save curve
        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/accuracy', train_acc, epoch)
        writer_val.add_scalar('global/loss', val_loss, epoch)
        writer_val.add_scalar('global/accuracy', val_acc, epoch)
        writer_train.add_scalar('accuracy/top1', train_accuracy_list[0], epoch)
        writer_train.add_scalar('accuracy/top3', train_accuracy_list[1], epoch)
        writer_train.add_scalar('accuracy/top5', train_accuracy_list[2], epoch)
        writer_val.add_scalar('accuracy/top1', val_accuracy_list[0], epoch)
        writer_val.add_scalar('accuracy/top3', val_accuracy_list[1], epoch)
        writer_val.add_scalar('accuracy/top5', val_accuracy_list[2], epoch)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint({'epoch': epoch + 1,
                         'net': args.net,
                         'state_dict': model.state_dict(),
                         'best_acc': best_acc,
                         'optimizer': optimizer.state_dict(),
                         'iteration': iteration},
                        is_best, filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch + 1)), keep_all=False)

    # flush pending TensorBoard events to disk
    writer_train.close()
    writer_val.close()
    print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs))
Exemple #10
0
def main():
    """Self-supervised DPC / DPC+ training with an auxiliary supervised target.

    Parses the command line with the module-level ``parser`` and seeds torch
    (CPU and CUDA) and numpy from ``args.seed``. It builds a ``DPC_RNN``
    (``--model dpc-rnn``) or ``DPC_Plus`` (``--model dpc-plus``) wrapped in
    ``nn.DataParallel``, on CUDA when available and on CPU otherwise. The
    auxiliary criterion is chosen from ``args.target`` / ``args.dataset``.

    If ``model_path/last.pth.tar`` exists, it becomes ``args.resume``, so
    re-running the same experiment continues where it stopped. When no
    checkpoint was resumed, ``--pretrain`` weights are loaded instead.

    Each epoch logs to TensorBoard (and to wandb with ``--wandb``) and tracks
    the best combined validation loss (CPC + auxiliary). It saves both an
    ``epoch<N>.pth.tar`` and a ``last.pth.tar`` checkpoint under ``model_path``.

    Assigns module-level globals ``args``, ``cuda``, ``criterion``,
    ``criterion_aux``, ``temperature``, ``iteration``, ``img_path``,
    ``de_normalize`` and ``writer_train``. Presumably these are read by
    ``train``/``validate`` defined elsewhere; confirm before renaming.

    Raises:
        ValueError: unknown ``--model`` or a dataset with no transform defined.
        NotImplementedError: unsupported ``args.target`` / dataset combination.
    """
    global args, cuda, criterion, criterion_aux, temperature, iteration
    global img_path, de_normalize, writer_train

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Fall back to CPU so the script also runs on machines without a GPU.
    if torch.cuda.is_available():
        cuda = torch.device('cuda')
    else:
        cuda = torch.device('cpu')

    ### dpc model ###
    if args.model == 'dpc-rnn':
        model_cls = DPC_RNN
    elif args.model == 'dpc-plus':
        model_cls = DPC_Plus
    else:
        raise ValueError('wrong model!')
    model = model_cls(sample_size=args.img_dim,
                      num_seq=args.num_seq,
                      seq_len=args.seq_len,
                      network=args.net,
                      pred_step=args.pred_step)

    model = nn.DataParallel(model)
    model = model.to(cuda)
    criterion = nn.CrossEntropyLoss()
    temperature = 1

    if args.wandb:
        # NOTE(review): depending on the installed wandb version, the first
        # positional argument of wandb.init is not the run name (older releases
        # bind it to job_type). Pass name=/project= explicitly if that was intended.
        wandb.init(f"CPC {args.prefix}", config=args)
        wandb.watch(model)

    ### optimizer ###
    if args.train_what == 'last':
        # freeze the backbone; its params get no grad and Adam skips them
        for name, param in model.module.resnet.named_parameters():
            param.requires_grad = False
    # any other train_what: train all layers

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # auxiliary (supervised) criterion for the selected target
    if args.target == 'obj_categ' and args.dataset in ('tdw', 'cifar10'):
        criterion_aux = nn.CrossEntropyLoss()
    elif args.target == 'self_motion':
        # alternative left by the author: nn.L1Loss(reduction='sum')
        criterion_aux = nn.MSELoss(reduction='sum')
    elif args.target == 'act_recog' and args.dataset == 'ucf101':
        criterion_aux = nn.CrossEntropyLoss()
    else:
        raise NotImplementedError(
            f"{args.target} is not a valid target variable or the selected dataset doesn't support this target variable"
        )

    args.old_lr = None
    best_loss = 1e10
    iteration = 0

    ### restart training ###
    img_path, model_path = set_path(args)
    # Auto-resume: the main loop below writes 'last.pth.tar' into model_path,
    # so that is where the latest checkpoint of this experiment lives.
    last_ckpt = os.path.join(model_path, 'last.pth.tar')
    if os.path.exists(last_ckpt):
        args.resume = last_ckpt

    resumed = False
    if args.resume:
        if os.path.isfile(args.resume):
            # The old lr is parsed from the run path ("..._lr<value>_..."); it is
            # used here only for the log line below, so tolerate its absence.
            lr_match = re.search('_lr(.+?)_', args.resume)
            args.old_lr = float(lr_match.group(1)) if lr_match else None
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            if not args.reset_lr:  # keep the saved optimizer state (incl. its lr)
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                old_lr = '%f' % args.old_lr if args.old_lr is not None else 'unknown'
                print('==== Change lr from %s to %f ====' % (old_lr, args.lr))
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            resumed = True
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    # Loading pretrained weights after a successful resume would overwrite the
    # resumed weights while keeping the resumed epoch/optimizer state.
    if args.pretrain and resumed:
        print("=> [Warning] ignoring pretrained checkpoint '{}' because training was resumed".format(args.pretrain))
    elif args.pretrain:
        if os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(
                args.pretrain))
            checkpoint = torch.load(args.pretrain,
                                    map_location=torch.device('cpu'))
            model = neq_load_customized(model, checkpoint['state_dict'])
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load data ###
    if args.dataset in ('ucf101', 'catcam'):  # short size=256, rand crop to 224x224 then scale to img_dim
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            RandomCrop(size=224, consistent=True),
            Scale(size=(args.img_dim, args.img_dim)),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset == 'k400':  # designed for kinetics400, short size=150, rand crop to 128x128
        transform = transforms.Compose([
            RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0),
            RandomHorizontalFlip(consistent=True),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset == 'airsim':
        transform = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            RandomCrop(size=112, consistent=True),
            Scale(size=(args.img_dim, args.img_dim)),
            RandomGray(consistent=False, p=0.5),
            ColorJitter(brightness=0.5,
                        contrast=0.5,
                        saturation=0.5,
                        hue=0.25,
                        p=1.0),
            ToTensor(),
            Normalize()
        ])
    elif args.dataset == 'tdw':
        # no augmentation; resize and normalize with dataset-specific statistics
        transform = transforms.Compose([
            Scale(size=(args.img_dim, args.img_dim)),
            ToTensor(),
            Normalize(mean=[0.5036, 0.4681, 0.4737],
                      std=[0.2294, 0.2624, 0.2830])
        ])
    else:
        # e.g. 'cifar10' passes the target check above but has no transform;
        # without this, transform would be unbound below.
        raise ValueError(f"no transform defined for dataset '{args.dataset}'")

    train_loader = get_data(transform, 'train')
    val_loader = get_data(transform, 'val')

    # setup tools
    de_normalize = denorm()

    try:  # old version
        writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train'))
    except TypeError:  # v1.7: keyword renamed to ``logdir``
        writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val'))
        writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train'))

    ### main loop ###
    save_checkpoint_freq = args.save_checkpoint_freq

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_acc, train_accuracy_list, train_loss_hd = train(
            train_loader, model, optimizer, epoch)

        val_loss, val_acc, val_accuracy_list, val_loss_hd = validate(
            val_loader, model, epoch)

        if args.wandb:
            wandb.log({
                "epoch": epoch,
                "cpc train loss": train_loss,
                "cpc train accuracy top1": train_accuracy_list[0],
                "cpc val loss": val_loss,
                "cpc val accuracy top1": val_accuracy_list[0],
                "heading train loss": train_loss_hd,
                "heading val loss": val_loss_hd
            })

        # save curve
        writer_train.add_scalar('global/loss', train_loss, epoch)
        writer_train.add_scalar('global/accuracy', train_acc, epoch)
        writer_val.add_scalar('global/loss', val_loss, epoch)
        writer_val.add_scalar('global/accuracy', val_acc, epoch)
        writer_train.add_scalar('accuracy/top1', train_accuracy_list[0], epoch)
        writer_train.add_scalar('accuracy/top3', train_accuracy_list[1], epoch)
        writer_train.add_scalar('accuracy/top5', train_accuracy_list[2], epoch)
        writer_val.add_scalar('accuracy/top1', val_accuracy_list[0], epoch)
        writer_val.add_scalar('accuracy/top3', val_accuracy_list[1], epoch)
        writer_val.add_scalar('accuracy/top5', val_accuracy_list[2], epoch)

        # Model selection uses the combined CPC + auxiliary validation loss.
        total_val_loss = val_loss + val_loss_hd
        is_best_loss = total_val_loss < best_loss
        best_loss = min(total_val_loss, best_loss)
        # passed to save_checkpoint as keep_all every save_checkpoint_freq epochs
        save_this = epoch % save_checkpoint_freq == 0

        state = {
            'epoch': epoch + 1,
            'net': args.net,
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
            'iteration': iteration
        }
        save_checkpoint(
            state,
            is_best_loss,
            filename=os.path.join(model_path,
                                  'epoch%s.pth.tar' % str(epoch + 1)),
            keep_all=save_this)
        # also refresh 'last.pth.tar', which the auto-resume logic looks for
        save_checkpoint(
            state,
            is_best_loss,
            filename=os.path.join(model_path, 'last.pth.tar'),
            keep_all=save_this)

    # flush pending TensorBoard events to disk
    writer_train.close()
    writer_val.close()
    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
Exemple #11
0
def _vae_kl_weight(epoch, args, train_with_latent):
    """Return the VAE KL-loss weight (beta) to use during `epoch`.

    The weight is 0 while the latent encoder is disabled and for
    `args.vae_inter_kl_epochs` epochs after it is enabled. It then follows a
    half-cosine ramp from 0 towards `args.vae_kl_weight` over
    `args.vae_kl_beta_warmup` epochs, and stays at `args.vae_kl_weight`
    afterwards.
    """
    if not train_with_latent:
        return 0.0
    epoch_post = epoch - (args.vae_encoderless_epochs + args.vae_inter_kl_epochs)
    if epoch_post < 0:
        # "train for a while with beta = 0" phase
        return 0.0
    if epoch_post < args.vae_kl_beta_warmup:
        # Beta warmup. A warmup of 0 never reaches this branch, so there is
        # no division by zero.
        return args.vae_kl_weight * \
            (1.0 - np.cos(epoch_post / args.vae_kl_beta_warmup * np.pi)) / 2
    return args.vae_kl_weight


def _log_epoch_scalars(writer, epoch, loss_dpc, loss_kl, loss_divers, loss, acc, acc_list):
    """Write one epoch's loss/accuracy curves to a TensorBoard writer.

    The positional values are, in order, the 6 items unpacked from `train()`
    / `validate()` in `main`. `acc_list` holds the top-1, top-3 and top-5
    accuracies in that order.
    """
    writer.add_scalar('global/loss_dpc', loss_dpc, epoch)
    writer.add_scalar('global/loss_vae_kl', loss_kl, epoch)
    writer.add_scalar('global/loss_vae_divers', loss_divers, epoch)
    writer.add_scalar('global/loss', loss, epoch)
    writer.add_scalar('global/accuracy', acc, epoch)
    for idx, k in enumerate((1, 3, 5)):
        writer.add_scalar('accuracy/top%d' % k, acc_list[idx], epoch)


def main():
    """Build a DPC CVAE/VRNN model, then either train it or measure its diversity.

    Command-line options come from the module-level `parser`.

    - Without `args.test_diversity`: trains over epochs
      [args.start_epoch, args.epochs). Each epoch is validated with and
      without the latent encoder, and also on present-matched sequences
      from `args.pm_start` on. Curves are logged to TensorBoard and a
      checkpoint is saved every epoch.
    - With `args.test_diversity`: measures the diversity of generated
      samples on the validation set and pickles the results.

    Raises:
        ValueError: if `args.model` is unknown, or if diversity testing is
            requested without a loaded resume/pretrain checkpoint (its epoch
            names the results file).
    """
    # Set constant random state for consistent results
    torch.manual_seed(704)
    np.random.seed(704)
    random.seed(704)

    # These are module-level globals; other functions in the file presumably
    # read/update them (e.g. `iteration` during training), so keep them global.
    global args, cuda, iteration, img_path, writer_train
    global cur_vae_kl_weight, cur_pred_divers_wt

    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    cuda = torch.device('cuda')
    args.cuda = cuda

    ### dpc model ###
    if args.model == 'cvae':
        model = DPC_CVAE(img_dim=args.img_dim,
                         num_seq=args.num_seq,
                         seq_len=args.seq_len,
                         pred_step=args.pred_step,
                         network=args.net,
                         cvae_arch=args.cvae_arch)
    elif 'vrnn' in args.model:
        model = DPC_VRNN(img_dim=args.img_dim,
                         num_seq=args.num_seq,
                         seq_len=args.seq_len,
                         pred_step=args.pred_step,
                         network=args.net,
                         latent_size=args.vrnn_latent_size,
                         kernel_size=args.vrnn_kernel_size,
                         rnn_dropout=args.vrnn_dropout,
                         time_indep='-i' in args.model)
    else:
        raise ValueError('Unknown / wrong model: ' + args.model)

    # parallelize the model
    model = nn.DataParallel(model)
    model = model.to(cuda)

    # NOTE(review): cross-entropy loss; a direct cosine distance was suggested
    # as an alternative in an earlier comment.
    criterion = nn.CrossEntropyLoss()

    ### optimizer ###
    # 'last' freezes the ResNet backbone; otherwise all layers are trained.
    if args.train_what == 'last':
        for param in model.module.resnet.parameters():
            param.requires_grad = False

    # check if gradient flowing to appropriate layers
    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    args.old_lr = None

    best_acc = 0
    iteration = 0
    # Epoch stored in the most recently loaded checkpoint (resume, then
    # pretrain); stays None when no checkpoint file was loaded.
    loaded_epoch = None

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            # Recover the old learning rate from a '..._lr<value>_...' path,
            # if present. It is only used to report an lr reset.
            lr_match = re.search('_lr(.+?)_', args.resume)
            args.old_lr = float(lr_match.group(1)) if lr_match else None
            print("=> loading resumed checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(
                args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            loaded_epoch = checkpoint['epoch']

            if not args.reset_lr:
                # continue with the saved optimizer state (including its lr)
                optimizer.load_state_dict(checkpoint['optimizer'])
            elif args.old_lr is not None:
                # the freshly built optimizer already uses args.lr
                print('==== Change lr from %f to %f ====' %
                      (args.old_lr, args.lr))
            else:
                print('==== Change lr to %f ====' % args.lr)
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            print("=> loading pretrained checkpoint '{}'".format(args.pretrain))
            checkpoint = torch.load(
                args.pretrain, map_location=torch.device('cpu'))

            # neq_load_customized is defined elsewhere in the project
            model = neq_load_customized(model, checkpoint['state_dict'])
            loaded_epoch = checkpoint['epoch']
            print("=> loaded pretrained checkpoint '{}' (epoch {})"
                  .format(args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrain))

    ### load transform & dataset ###
    transform_train = get_transform(args, mode='train')
    transform_val = get_transform(args, mode='val')

    train_loader = get_data(args, transform_train, 'train')
    val_loader = get_data(args, transform_val, 'val')

    # Initialize denormalize transform to display input images via tensorboard
    de_normalize = denorm()

    # Get paths
    img_path, model_path, divers_path, pm_cache = set_path(args)

    if not args.test_diversity:

        # Train & validate for multiple epochs; one TensorBoard writer per curve set
        writer_val_enc = SummaryWriter(os.path.join(img_path, 'val_enc'))
        writer_val_noenc = SummaryWriter(os.path.join(img_path, 'val_noenc'))
        writer_val_pm_enc = SummaryWriter(os.path.join(img_path, 'val_pm_enc'))
        writer_val_pm_noenc = SummaryWriter(os.path.join(img_path, 'val_pm_noenc'))
        writer_train = SummaryWriter(os.path.join(img_path, 'train'))
        writers = (writer_train, writer_val_enc, writer_val_noenc,
                   writer_val_pm_enc, writer_val_pm_noenc)

        try:
            ### main loop ###
            for epoch in range(args.start_epoch, args.epochs):
                # Initially train without latent space if specified
                train_with_latent = epoch >= args.vae_encoderless_epochs
                # Encouraging diversity doesn't apply while the encoder is off;
                # hard transition once it is enabled.
                cur_pred_divers_wt = args.pred_divers_wt if train_with_latent else 0.0
                print('Encoder enabled for training this epoch:', train_with_latent)

                # Beta = 0 for a while, then beta warmup (see _vae_kl_weight)
                cur_vae_kl_weight = _vae_kl_weight(epoch, args, train_with_latent)
                print('Current VAE KL loss term weight:', cur_vae_kl_weight)

                # With the encoder on, validate a single path only; otherwise
                # keep the best among args.select_best_among.
                enc_select_best = 1 if train_with_latent else args.select_best_among
                # Present-matching validation results (only for epoch >= pm_start)
                val_pm_enc = val_pm_noenc = None

                if epoch < args.pm_start:
                    print('\n\nTraining with randomly sampled sequences from %s' %
                          args.dataset)
                    train_stats = train(
                        train_loader, model, optimizer, epoch, criterion, writer_train, args,
                        de_normalize, cur_vae_kl_weight, cur_pred_divers_wt,
                        do_encode=train_with_latent)

                else:
                    # change torch multithreading setting to circumvent 'open files' limit
                    torch.multiprocessing.set_sharing_strategy('file_system')

                    print('\n\nTraining with matched sequences (nearest neighbours) from %s\n\n' % args.dataset)
                    if epoch % args.pm_freq == 0:
                        print(' ######################################## \n Matching sequences in embedding space... \n ########################################\n\n')

                        # present matching and save results for future training
                        present_matching(args, model, pm_cache, epoch, 'train')
                        present_matching(args, model, pm_cache, epoch, 'val')

                    args_pm = copy.copy(args)
                    args_pm.dataset = 'epic_present_matching'

                    # Retrieve the results of the most recent present-matching epoch.
                    # NOTE(review): if pm_start (or a resumed start_epoch) is not a
                    # multiple of pm_freq, this epoch's matching was not computed in
                    # this run; confirm the cache exists in that case.
                    pm_epoch = epoch - epoch % args_pm.pm_freq

                    pm_train_path = os.path.join(pm_cache, 'epoch_train_%d' % pm_epoch)
                    train_loader_pm = get_data(
                        args_pm, transform_train, mode='train', pm_path=pm_train_path, epoch=epoch)
                    train_stats = train(
                        train_loader_pm, model, optimizer, epoch, criterion, writer_train, args,
                        de_normalize, cur_vae_kl_weight, cur_pred_divers_wt,
                        do_encode=train_with_latent)
                    del train_loader_pm

                    pm_val_path = os.path.join(pm_cache, 'epoch_val_%d' % pm_epoch)
                    val_loader_pm = get_data(
                        args_pm, transform_val, mode='val', pm_path=pm_val_path, epoch=epoch)
                    val_pm_enc = validate(
                        val_loader_pm, model, epoch, args, criterion, cur_vae_kl_weight,
                        cur_pred_divers_wt, do_encode=train_with_latent,
                        select_best_among=enc_select_best)
                    val_pm_noenc = validate(
                        val_loader_pm, model, epoch, args, criterion, cur_vae_kl_weight,
                        cur_pred_divers_wt, do_encode=False,
                        select_best_among=args.select_best_among)
                    del val_loader_pm

                val_enc = validate(
                    val_loader, model, epoch, args, criterion, cur_vae_kl_weight,
                    cur_pred_divers_wt, do_encode=train_with_latent,
                    select_best_among=enc_select_best)
                val_noenc = validate(
                    val_loader, model, epoch, args, criterion, cur_vae_kl_weight,
                    cur_pred_divers_wt, do_encode=False,
                    select_best_among=args.select_best_among)

                # Train / val curves
                _log_epoch_scalars(writer_train, epoch, *train_stats)
                writer_train.add_scalar('global/vae_kl_weight', cur_vae_kl_weight, epoch)
                _log_epoch_scalars(writer_val_enc, epoch, *val_enc)
                _log_epoch_scalars(writer_val_noenc, epoch, *val_noenc)
                if val_pm_enc is not None:
                    # Add present matching curves
                    _log_epoch_scalars(writer_val_pm_enc, epoch, *val_pm_enc)
                    _log_epoch_scalars(writer_val_pm_noenc, epoch, *val_pm_noenc)

                # save check_point (best accuracy measured without encoder)
                val_acc_noenc = val_noenc[4]
                is_best = val_acc_noenc > best_acc
                best_acc = max(val_acc_noenc, best_acc)
                save_checkpoint({'epoch': epoch + 1,
                                 'net': args.net,
                                 'state_dict': model.state_dict(),
                                 'best_acc': best_acc,
                                 'optimizer': optimizer.state_dict(),
                                 'iteration': iteration},
                                is_best, filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch + 1)), save_every=10)
        finally:
            # flush pending TensorBoard events even if training is interrupted
            for writer in writers:
                writer.close()

        print('Training from ep %d to ep %d finished' %
              (args.start_epoch, args.epochs))

    else:
        # The results file is named after the loaded checkpoint's epoch; fail
        # before the (long) measurement rather than after it.
        if loaded_epoch is None:
            raise ValueError('test_diversity requires a checkpoint loaded via '
                             'resume or pretrain, but none was loaded')
        cur_path = os.path.join(divers_path, 'epoch' + str(loaded_epoch) +
                                '_paths' + str(args.paths) +
                                ('_ctdrop' if args.force_context_dropout else '') + '.p')

        # Uncertainty evaluation: full video & no color adjustments
        # NOTE: be careful with training augmentation to prevent train-test resolution / scaling discrepancy
        tf_diversity = transforms.Compose([
            RandomHorizontalFlip(consistent=True),
            CenterCrop(224),
            Scale(size=(args.img_dim, args.img_dim)),
            ToTensor(),
            Normalize()
        ])

        print('Measuring diversity of generated samples...')
        val_divers_loader = get_data(args, tf_diversity, 'val')
        results = measure_vae_uncertainty_loader(val_divers_loader, model, args.paths,
                                                 print_freq=args.print_freq, collect_actions=False,
                                                 force_context_dropout=args.force_context_dropout)
        with open(cur_path, 'wb') as f:
            pickle.dump(results, f)
        print('For future use, uncertainty evaluation results stored to ' + cur_path)