예제 #1
0
def evaluate(model, dataset, device):
    """Evaluate ``model`` on ``dataset`` without updating any weights.

    Args:
        model: network called as ``model(x, lengths)``. Its output is fed to
            ``nn.CrossEntropyLoss`` and arg-maxed over dim 1, so it presumably
            returns ``(batch, num_classes)`` scores — confirm against the model.
        dataset: dataset whose batches provide ``'data'``, ``'label'`` and
            ``'lengths'`` entries.
        device: device that inputs and labels are moved to.

    Returns:
        ``(total_loss, total_acc)``: AverageMeter objects holding the
        per-batch cross-entropy loss and per-batch accuracy (in percent).
    """
    dataloader = DataLoader(dataset, batch_size=args.val_batch_size,
                            shuffle=False, num_workers=4)
    loss_fn = nn.CrossEntropyLoss()
    total_loss = AverageMeter()
    total_acc = AverageMeter()
    # Inference mode (affects layers such as dropout / batch-norm); no
    # gradients are needed for evaluation.
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x = prepare_batch(batch['data'].to(device=device))
            y = batch['label'].to(device=device)
            lengths = batch['lengths']
            y_hat = model(x, lengths)
            loss = loss_fn(y_hat, y)
            predictions = torch.max(y_hat, 1)[1].view(y.size())
            num_correct = (predictions == y).float().sum()
            # Divide by the actual number of examples: the last batch can be
            # smaller than args.val_batch_size since the loader keeps it.
            acc = 100.0 * num_correct / y.size(0)
            total_loss.update(loss.item())
            total_acc.update(acc.item())
    return total_loss, total_acc
예제 #2
0
def train_epoch(model, elmo, dataset, device, lr, epoch=1):
    """Train ``model`` for one pass over ``dataset``.

    Each batch is run through ``prepare_batch`` and ``elmo``. The first ELMo
    representation is transposed on dims 0/1 and fed to
    ``model(x, lengths)``, then one Adam step is taken on the cross-entropy
    loss.

    Args:
        model: classifier being trained.
        elmo: embedder returning a dict with an ``'elmo_representations'``
            list; only index 0 is used.
        dataset: dataset whose batches provide ``'data'``, ``'label'`` and
            ``'lengths'``.
        device: device that inputs and labels are moved to.
        lr: learning rate for this epoch's Adam optimizer.
        epoch: epoch number, used only in log messages.

    Returns:
        ``(total_epoch_loss, total_epoch_acc)``: AverageMeter objects with the
        per-batch loss and per-batch accuracy (in percent).
    """
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)
    loss_fn = nn.CrossEntropyLoss()
    total_epoch_loss = AverageMeter()
    total_epoch_acc = AverageMeter()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    # NOTE: a new Adam optimizer is created every epoch so the (possibly
    # decayed) ``lr`` takes effect; this also resets Adam's moment estimates.
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    steps = 0
    model.train()
    end = time.time()
    for batch in dataloader:
        x = prepare_batch(batch['data']).to(device=device)
        # Record loading + preparation time exactly once per step; ELMo
        # embedding time is counted only in batch_time.
        data_time.update(time.time() - end)
        x = elmo(x)['elmo_representations'][0]
        # Swap the first two dims; presumably (batch, seq, dim) ->
        # (seq, batch, dim) as the model expects — confirm.
        x = x.transpose(0, 1)
        y = batch['label'].to(device=device)
        lengths = batch['lengths']

        model.zero_grad()
        y_hat = model(x, lengths)
        loss = loss_fn(y_hat, y)
        predictions = torch.max(y_hat, 1)[1].view(y.size())
        num_correct = (predictions == y).float().sum()
        # Use the real batch size: the last batch may be smaller than
        # args.batch_size because the loader does not drop it.
        acc = 100.0 * num_correct / y.size(0)
        loss.backward()
        optim.step()

        steps += 1
        total_epoch_loss.update(loss.item())
        total_epoch_acc.update(acc.item())
        batch_time.update(time.time() - end)
        if steps % args.log_interval == 0:
            print(f'Epoch: {epoch}, batch: {steps}' +
                  f', average data time: {data_time.average():.3f}' +
                  f', average batch time: {batch_time.average():.3f}' +
                  f', Training Loss: {total_epoch_loss.average():.4f}' +
                  f', Training Accuracy: {total_epoch_acc.average(): .2f}%')

        end = time.time()
    return total_epoch_loss, total_epoch_acc
예제 #3
0
def train(model, elmo, train_dataset, device='cuda', val_dataset=None):
    """Train ``model`` for ``args.epochs`` epochs, optionally validating.

    After every epoch the mean training loss and accuracy are printed. When
    ``val_dataset`` is given, the model is also evaluated. The learning rate
    is then multiplied by ``args.lr_decay`` whenever the validation loss is
    higher than in the previous epoch. A KeyboardInterrupt stops training
    early without propagating.

    Returns:
        ``(train_log, validation_log)``: dicts with ``'acc'`` and ``'loss'``
        AverageMeters holding one per-epoch average each. ``validation_log``
        stays empty when no ``val_dataset`` is given. (The original returned
        None; callers ignoring the result are unaffected.)
    """
    lr = args.lr
    train_log = {'acc': AverageMeter(), 'loss': AverageMeter()}
    validation_log = {'acc': AverageMeter(), 'loss': AverageMeter()}
    # Validation loss from the previous epoch; None until one has been seen.
    prev_val_loss = None
    try:
        for epoch in range(1, args.epochs + 1):
            print('Training with learning rate of: ' + f'{lr:1.6f}')
            start_time = time.time()
            train_loss, train_acc = train_epoch(model,
                                                elmo,
                                                train_dataset,
                                                epoch=epoch,
                                                device=device,
                                                lr=lr)
            train_log['acc'].update(train_acc.average())
            train_log['loss'].update(train_loss.average())

            print('-' * 89)
            print(
                '| end of epoch {:3d} | time: {:5.2f}s | train loss {:5.2f} | '
                'train accuracy {:8.2f}'.format(epoch,
                                                (time.time() - start_time),
                                                train_loss.average(),
                                                train_acc.average()))
            if val_dataset is not None:
                start_time = time.time()
                # NOTE(review): this passes ``elmo``, so it needs an
                # evaluate(model, elmo, dataset, device) signature; confirm
                # the module's evaluate matches.
                val_loss, val_acc = evaluate(model, elmo, val_dataset, device)
                print(
                    '| end of epoch {:3d} | time: {:5.2f}s | validation loss {:5.2f} | '
                    'validation accuracy {:8.2f}'.format(
                        epoch, (time.time() - start_time), val_loss.average(),
                        val_acc.average()))
                epoch_val_loss = val_loss.average()
                # Anneal the learning rate when validation loss got worse.
                if prev_val_loss is not None and prev_val_loss < epoch_val_loss:
                    lr = lr * args.lr_decay
                prev_val_loss = epoch_val_loss
                validation_log['acc'].update(val_acc.average())
                validation_log['loss'].update(epoch_val_loss)
            print('-' * 89)
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
    return train_log, validation_log
예제 #4
0
    def userConfig(self):
        """
        Include the task-specific setup here.

        Fills ``self.args`` with the packed output-channel layout and fixed
        network/loss hyper-parameters. It also creates the TensorBoard writer,
        logger, metric meters, evaluation buffers, SIFT detector, semantic
        color palette and class-balance weights used elsewhere by this object.

        Raises:
            ValueError: if ``args.featurelearning`` is set but ``'f'`` is not
                in ``args.outputType``.
        """
        # Explicit raise rather than ``assert`` so the consistency check
        # still runs under ``python -O``.
        if self.args.featurelearning and 'f' not in self.args.outputType:
            raise ValueError(
                "featurelearning requires 'f' in outputType, got %r"
                % (self.args.outputType,))

        # Output heads are packed in this fixed order. Each selected head gets
        # args.idx_<key>_start / args.idx_<key>_end with end = start + width.
        # A string width names the args attribute holding it; it is read only
        # when that head is selected.
        layout = (('rgb', 3), ('n', 3), ('d', 1), ('k', 1),
                  ('s', 'snumclass'), ('f', 'featureDim'))
        pointer = 0
        for key, width in layout:
            if key not in self.args.outputType:
                continue
            if isinstance(width, str):
                width = getattr(self.args, width)
            setattr(self.args, 'idx_%s_start' % key, pointer)
            setattr(self.args, 'idx_%s_end' % key, pointer + width)
            pointer += width

        self.args.num_output = pointer
        self.args.num_input = 8 * 2
        self.args.ngpu = 1
        self.args.nz = 100
        self.args.ngf = 64
        self.args.ndf = 64
        self.args.nef = 64
        self.args.nBottleneck = 4000
        self.args.wt_recon = 0.998
        self.args.wtlD = 0.002
        self.args.overlapL2Weight = 10

        # setup logger
        self.tensorboardX = SummaryWriter(log_dir=os.path.join(self.args.EXP_DIR, 'tensorboard'))
        self.logger = log.logging(self.args.EXP_DIR_LOG)
        self.logger_errG = AverageMeter()
        self.logger_errG_recon = AverageMeter()
        self.logger_errG_rgb = AverageMeter()
        self.logger_errG_d = AverageMeter()
        self.logger_errG_n = AverageMeter()
        self.logger_errG_s = AverageMeter()
        self.logger_errG_k = AverageMeter()

        self.logger_errD_fake = AverageMeter()
        self.logger_errD_real = AverageMeter()
        self.logger_errG_fl = AverageMeter()
        self.logger_errG_fl_pos = AverageMeter()
        self.logger_errG_fl_neg = AverageMeter()
        self.logger_errG_fl_f = AverageMeter()
        self.logger_errG_fc = AverageMeter()
        self.logger_errG_pn = AverageMeter()
        self.logger_errG_freq = AverageMeter()

        self.global_step = 0
        self.speed_benchmark = True
        if self.speed_benchmark:
            self.time_per_step = AverageMeter()

        self.sift = cv2.xfeatures2d.SIFT_create()
        self.evalFeatRatioDL_obs, self.evalFeatRatioDL_unobs = [], []
        self.evalFeatRatioDLc_obs, self.evalFeatRatioDLc_unobs = [], []
        self.evalFeatRatioSift = []
        self.evalErrN = []
        self.evalErrD = []
        self.evalSemantic = []
        self.evalSemantic_gt = []

        self.sancheck = {}

        # semantic encoding
        # NOTE(review): if dataList names none of these datasets,
        # self.colors is never assigned; confirm that is intended.
        if 'scannet' in self.args.dataList:
            self.colors = config.scannet_color_palette
        elif 'matterport' in self.args.dataList:
            self.colors = config.matterport_color_palette
        elif 'suncg' in self.args.dataList:
            self.colors = config.suncg_color_palette

        self.class_balance_weights = torch_op.v(np.ones([self.args.snumclass]))