def train(model, task, y_list, x_list, checkpoint_dir, checkpoint_prefix, device, batch_size=512, max_seq_len=100, lr=1e-3, resume_surfix=None, logger=None):
	"""
	: model - torch.nn.Module: model to be trained
	: task - list[tuple(int,list[int])]: (epoch, list of file indices) pairs to train on
	: y_list - list[str]: list of y variables
	: x_list - list[str]: list of x variables to generate embed sequences for
	: checkpoint_dir - str: path to checkpoint directory
	: checkpoint_prefix - str: prefix of checkpoint file
	: device - torch.device: device to train the model on
	: batch_size - int: size of mini batch
	: max_seq_len - int: max length for sequence input, default 100
	: lr - float: learning rate, default 1e-3 (unused in this variant; the transformer scheduler sets the rate)
	: resume_surfix - str: '{epoch}_{file}' checkpoint suffix to reload if not training from scratch
	"""
	global input_split_path, embed_path
	if not gc.isenabled(): gc.enable()

	# Check checkpoint directory
	if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)

	# Calculate number of batches
	div, mod = divmod(90000, batch_size)
	batch_per_file = div + min(1, mod)
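	# div + min(1, mod) is a ceiling division: a final partial batch still counts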
	batch_per_epoch = 9 * batch_per_file

	# Set up loss function and optimizer
	last_step = -1
	loss_fn = nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam([{'params':model.parameters(), 'initial_lr':1}], betas=(0.9, 0.98), eps=1e-9, amsgrad=True)
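	# 'initial_lr': 1 makes the (presumably LambdaLR-style) transformer scheduler below
	# the sole source of the effective learning rate, since it scales this base of 1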

	if resume_surfix is not None:
		model_artifact_path = os.path.join(checkpoint_dir, '{}_{}.pth'.format(checkpoint_prefix, resume_surfix))
		model.load_state_dict(torch.load(model_artifact_path))
		if logger: logger.info('Model loaded from {}'.format(model_artifact_path))
		optimizer_artifact_path = os.path.join(checkpoint_dir, '{}_{}_opti.pth'.format(checkpoint_prefix, resume_surfix))
		optimizer.load_state_dict(torch.load(optimizer_artifact_path))
		if logger: logger.info('Optimizer loaded from {}'.format(optimizer_artifact_path))

		t = resume_surfix.split('_')
		ep, fi = int(t[0]), int(t[1])
		last_step = (ep-1)*batch_per_epoch+fi*batch_per_file-1
		if logger: logger.info('Learning rate resumed from step {}'.format(last_step+1))

	scheduler = get_transformer_scheduler(optimizer, 512, 1000, last_step=last_step)
	model.to(device)
	
	# Initialize word vector host
	wv = wv_loader_v2(x_list, embed_path, max_seq_len=max_seq_len)
	if logger: logger.info('Word vector host ready')
	
	# Main Loop
	for epoch, file_idx_list in task:
		if logger:
			logger.info('=========================')
			logger.info('Processing Epoch {}/{}'.format(epoch, task[-1][0]))
			logger.info('=========================')

		# Train model
		model.train()
		train_running_loss, train_n_batch = 0, 0

		for index, split_idx in enumerate(file_idx_list, start=1):
			dl = data_loader_v2(wv, y_list, x_list, input_split_path, split_idx, batch_size=batch_size, shuffle=True)
			it = iter(dl)
			while True:
				try:
					yl, xl, x_seq_len = next(it)
					y = yl[0].to(device)
					x = torch.cat(xl, dim=2).to(device)

					optimizer.zero_grad()
					yp = model(x, x_seq_len)  # raw logits; CrossEntropyLoss applies log-softmax internally
					loss = loss_fn(yp, y)

					loss.backward()
					torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
					optimizer.step()

					train_running_loss += loss.item()
					train_n_batch += 1

					scheduler.step()

				except StopIteration:
					break

				except Exception as e:
					if logger: logger.error(e)
					return 

			del dl, it
			_ = gc.collect()

			if logger:
				logger.info('Epoch {}/{} - File {}/8 Done - Train Loss: {:.6f}, Learning Rate {:.7f}'.format(epoch, task[-1][0], index, train_running_loss/train_n_batch, optimizer.param_groups[0]['lr']))

			# Save model & optimizer state dict
			ck_file_name = '{}_{}_{}.pth'.format(checkpoint_prefix, epoch, split_idx)
			ck_file_path = os.path.join(checkpoint_dir, ck_file_name)
			torch.save(model.state_dict(), ck_file_path)
			op_file_name = '{}_{}_{}_opti.pth'.format(checkpoint_prefix, epoch, split_idx)
			op_file_path = os.path.join(checkpoint_dir, op_file_name)
			torch.save(optimizer.state_dict(), op_file_path)

		torch.cuda.empty_cache()

		# Evaluate model
		model.eval()
		test_running_loss, test_n_batch = 0, 0
		true_y, pred_y = [], []

		with torch.no_grad():
			for split_idx in [9, 10]:
				dl = data_loader_v2(wv, y_list, x_list, input_split_path, split_idx, batch_size=batch_size, shuffle=True)
				it = iter(dl)
				while True:
					try:
						yl, xl, x_seq_len = next(it)
						y = yl[0].to(device)
						x = torch.cat(xl, dim=2).to(device)
						yp = model(x, x_seq_len)
						loss = loss_fn(yp, y)

						pred_y.extend(list(F.softmax(yp, dim=1).cpu().detach().numpy()))
						true_y.extend(list(y.cpu().detach().numpy()))

						test_running_loss += loss.item()
						test_n_batch += 1

					except StopIteration:
						break

					except Exception as e:
						if logger: logger.error(e)
						return 

				del dl, it
				_ = gc.collect()

		pred = np.argmax(np.array(pred_y), 1)
		true = np.array(true_y).reshape((-1,))
		acc_score = accuracy_score(true, pred)

		del pred, true, pred_y, true_y
		_ = gc.collect()

		if logger:
			logger.info('Epoch {}/{} Done - Test Loss: {:.6f}, Test Accuracy: {:.6f}'.format(epoch, task[-1][0], test_running_loss/test_n_batch, acc_score))
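
# `get_transformer_scheduler` is project-local and not shown in this listing. Below is a
# minimal sketch of a plausible implementation, assuming the usual "Noam" inverse-square-root
# warmup from "Attention Is All You Need" (d_model and warmup_steps matching the 512 / 1000
# arguments passed above). Since the optimizer sets 'initial_lr' to 1, the lambda's value is
# the effective learning rate. This is an assumption, not the repository's actual definition.
def get_transformer_scheduler(optimizer, d_model, warmup_steps, last_step=-1):
	def lr_lambda(step):
		step = max(step, 1)  # guard against step 0 in the inverse-square-root term
		return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))
	return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_step)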
def train(model,
          task,
          y_list,
          x_list,
          checkpoint_dir,
          checkpoint_prefix,
          device,
          batch_size=512,
          max_seq_len=100,
          lr=1e-3,
          resume_surfix=None,
          logger=None):
    """
	: model - torch.nn.Module: model to be trained
	: task - list[tuple(int,list[int])]: (epoch, list of file indices) pairs to train on
	: y_list - list[str]: list of y variables
	: x_list - list[str]: list of x variables to generate embed sequences for
	: checkpoint_dir - str: path to checkpoint directory
	: checkpoint_prefix - str: prefix of checkpoint file
	: device - torch.device: device to train the model on
	: batch_size - int: size of mini batch
	: max_seq_len - int: max length for sequence input, default 100
	: lr - float: learning rate for AdamW, default 1e-3
	: resume_surfix - str: checkpoint suffix to reload if not training from scratch
	"""
    global input_split_path, embed_path
    if not gc.isenabled(): gc.enable()

    # Check checkpoint directory
    if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)

    # Initialize word vector host
    wv = wv_loader_v2(x_list, embed_path, max_seq_len=max_seq_len)
    if logger: logger.info('Word vector host ready')

    # Load model if not train from scratch
    if resume_surfix is not None:
        model_artifact_path = os.path.join(
            checkpoint_dir, '{}_{}.pth'.format(checkpoint_prefix,
                                               resume_surfix))
        model.load_state_dict(torch.load(model_artifact_path))
        if logger:
            logger.info('Model loaded from {}'.format(model_artifact_path))

    # Set up loss function and optimizer
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=0,
        threshold=1e-5,
        threshold_mode='abs')
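    # mode='max': this scheduler is stepped once per epoch with the summed
    # age + gender accuracy (scheduler.step(acc_gender + acc_age) below)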

    # Main Loop
    for epoch, file_idx_list in task:
        if logger:
            logger.info('=========================')
            logger.info('Processing Epoch {}/{}'.format(epoch, task[-1][0]))
            logger.info('=========================')

        # Train model
        model.train()
        train_age_loss, train_gender_loss, train_n_batch = 0, 0, 0

        for split_idx in file_idx_list:
            dl = data_loader_v2(wv,
                                y_list,
                                x_list,
                                input_split_path,
                                split_idx,
                                batch_size=batch_size,
                                shuffle=True)
            it = iter(dl)
            while True:
                try:
                    yl, xl, x_seq_len = next(it)
                    y_age = yl[0].to(device)
                    y_gender = yl[1].to(device)
                    x = [i.to(device) for i in xl] + [x_seq_len]

                    optimizer.zero_grad()
                    # Assumes the model returns a pair of logit tensors (age, gender);
                    # CrossEntropyLoss expects raw logits, so no softmax before the loss
                    yp = model(*x)
                    l_age = loss_fn(yp[0], y_age)
                    l_gender = loss_fn(yp[1], y_gender)
                    loss = l_age + l_gender

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_norm=100)
                    optimizer.step()

                    train_age_loss += l_age.item()
                    train_gender_loss += l_gender.item()
                    train_n_batch += 1

                except StopIteration:
                    break

                except Exception as e:
                    if logger: logger.error(e)
                    return

            del dl, it
            _ = gc.collect()

            if logger:
                logger.info(
                    'Epoch {}/{} - File {}/9 Done - Train Loss - Age: {:.6f}, Gender: {:.6f}'
                    .format(epoch, task[-1][0], split_idx,
                            train_age_loss / train_n_batch,
                            train_gender_loss / train_n_batch))

            # Save model state dict
            ck_file_name = '{}_{}_{}.pth'.format(checkpoint_prefix, epoch,
                                                 split_idx)
            ck_file_path = os.path.join(checkpoint_dir, ck_file_name)

            torch.save(model.state_dict(), ck_file_path)

        torch.cuda.empty_cache()

        # Evaluate model
        model.eval()
        test_age_loss, test_gender_loss, test_n_batch = 0, 0, 0
        true_age, pred_age, true_gender, pred_gender = [], [], [], []

        with torch.no_grad():
            dl = data_loader_v2(wv,
                                y_list,
                                x_list,
                                input_split_path,
                                10,
                                batch_size=batch_size,
                                shuffle=True)
            it = iter(dl)
            while True:
                try:
                    yl, xl, x_seq_len = next(it)
                    y_age = yl[0].to(device)
                    y_gender = yl[1].to(device)
                    x = [i.to(device) for i in xl] + [x_seq_len]  # match the train-time inputs
                    yp = model(*x)  # pair of (age, gender) logit tensors, as in training
                    l_age = loss_fn(yp[0], y_age)
                    l_gender = loss_fn(yp[1], y_gender)
                    loss = l_age + l_gender

                    pred_age.extend(list(F.softmax(yp[0], dim=1).cpu().detach().numpy()))
                    true_age.extend(list(y_age.cpu().detach().numpy()))
                    pred_gender.extend(list(F.softmax(yp[1], dim=1).cpu().detach().numpy()))
                    true_gender.extend(list(y_gender.cpu().detach().numpy()))

                    test_age_loss += l_age.item()
                    test_gender_loss += l_gender.item()
                    test_n_batch += 1

                except StopIteration:
                    break

                except Exception as e:
                    if logger: logger.error(e)
                    return

            del dl, it
            _ = gc.collect()

        pred_age = np.argmax(np.array(pred_age), 1)
        true_age = np.array(true_age).reshape((-1, ))
        acc_age = accuracy_score(true_age, pred_age)

        pred_gender = np.argmax(np.array(pred_gender), 1)
        true_gender = np.array(true_gender).reshape((-1, ))
        acc_gender = accuracy_score(true_gender, pred_gender)

        del pred_age, true_age, pred_gender, true_gender
        _ = gc.collect()

        if logger:
            logger.info(
                'Epoch {}/{} Done - Age Loss: {:.6f}, Age Accuracy: {:.6f}'.
                format(epoch, task[-1][0], test_age_loss / test_n_batch,
                       acc_age))
            logger.info(
                'Epoch {}/{} Done - Gender Loss: {:.6f}, Gender Accuracy: {:.6f}'
                .format(epoch, task[-1][0], test_gender_loss / test_n_batch,
                        acc_gender))

        scheduler.step(acc_gender + acc_age)
        if logger:
            logger.info('Epoch {}/{} - Updated Learning Rate: {:.8f}'.format(
                epoch, task[-1][0], optimizer.param_groups[0]['lr']))
# Example #3
def pred(model,
         x_list,
         checkpoint_dir,
         checkpoint_prefix,
         output_dir,
         output_prefix,
         device,
         load_surfix,
         batch_size=512,
         max_seq_len=100,
         logger=None):
    """
	: model - torch.nn.Module: model to run inference with
	: x_list - list[str]: list of x variables to generate embed sequences for
	: checkpoint_dir - str: path to checkpoint directory
	: checkpoint_prefix - str: prefix of checkpoint file
	: output_dir - str: path to output directory
	: output_prefix - str: prefix of output file
	: device - torch.device: device to run the model on
	: load_surfix - str: suffix of the model artifact to load
	: batch_size - int: size of mini batch
	: max_seq_len - int: max length for sequence input, default 100
	"""
    global input_split_path, embed_path
    if not gc.isenabled(): gc.enable()

    # Initialize word vector host
    wv = wv_loader_v2(x_list, embed_path, max_seq_len=max_seq_len)
    if logger: logger.info('Word vector host ready')

    # Load model
    model_artifact_path = os.path.join(
        checkpoint_dir, '{}_{}.pth'.format(checkpoint_prefix, load_surfix))
    model.load_state_dict(torch.load(model_artifact_path))
    if logger: logger.info('Model loaded from {}'.format(model_artifact_path))
    model.to(device)
    model.eval()

    # Main Loop
    pred_y = []

    for file_idx in np.arange(1, 11):
        with torch.no_grad():
            dl = data_loader_v2(wv, [],
                                x_list,
                                input_split_path,
                                file_idx,
                                batch_size=batch_size,
                                shuffle=False,
                                train=False)
            it = iter(dl)
            while True:
                try:
                    _, xl, x_seq_len = next(it)
                    x = [i.to(device) for i in xl] + [x_seq_len - 1]
                    yp = F.softmax(model(*x), dim=1)
                    pred_y.extend(list(yp.cpu().detach().numpy()))

                except StopIteration:
                    break

                except Exception as e:
                    if logger: logger.error(e)
                    return

            del dl, it
            _ = gc.collect()

        if logger:
            logger.info('File {}/10 done with {}'.format(
                file_idx, model_artifact_path))

    pred = np.array(pred_y)

    save_path = os.path.join(output_dir,
                             '{}_{}.npy'.format(output_prefix, load_surfix))
    with open(save_path, 'wb') as f:
        np.save(f, pred)

    if logger:
        logger.info('Prediction result is saved to {}'.format(save_path))
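
# Hedged usage sketch (hypothetical paths, prefixes and feature names; the real driver
# script is not part of this listing). `task` pairs each epoch with the split files to
# train on, and the surfix strings follow the '{epoch}_{split}' checkpoint naming above:
#
#   task = [(1, [1, 2, 3, 4, 5, 6, 7, 8, 9]), (2, [1, 2, 3, 4, 5, 6, 7, 8, 9])]
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   train(model, task, ['age', 'gender'], ['creative_id'], './checkpoint', 'transformer', device)
#   pred(model, ['creative_id'], './checkpoint', 'transformer', './output', 'prob', device, '2_9')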
# Example #4
def train(model,
          task,
          y_list,
          x_list,
          checkpoint_dir,
          checkpoint_prefix,
          device,
          batch_size=512,
          max_seq_len=100,
          lr=1e-3,
          resume_surfix=None,
          logger=None):
    """
	: model - torch.nn.Module: model to be trained
	: task - list[tuple(int,list[int])]: (epoch, list of file indices) pairs to train on
	: y_list - list[str]: list of y variables
	: x_list - list[str]: list of x variables to generate embed sequences for
	: checkpoint_dir - str: path to checkpoint directory
	: checkpoint_prefix - str: prefix of checkpoint file
	: device - torch.device: device to train the model on
	: batch_size - int: size of mini batch
	: max_seq_len - int: max length for sequence input, default 100
	: lr - float: learning rate for AdamW, default 1e-3
	: resume_surfix - str: checkpoint suffix to reload if not training from scratch
	"""
    global input_split_path, embed_path
    if not gc.isenabled(): gc.enable()

    # Check checkpoint directory
    if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)

    # Initialize word vector host
    wv = wv_loader_v2(x_list, embed_path, max_seq_len=max_seq_len)
    if logger: logger.info('Word vector host ready')

    # Load model if not train from scratch
    if resume_surfix is not None:
        model_artifact_path = os.path.join(
            checkpoint_dir, '{}_{}.pth'.format(checkpoint_prefix,
                                               resume_surfix))
        model.load_state_dict(torch.load(model_artifact_path))
        if logger:
            logger.info('Model loaded from {}'.format(model_artifact_path))

    # Set up loss function and optimizer
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    div, mod = divmod(90000, batch_size)
    batch_per_file = div + min(mod, 1)
    total_steps = batch_per_file * 9 * task[-1][0]

    start_epoch = task[0][0]
    start_file = task[0][1][0]
    last_batch = (start_epoch - 1) * batch_per_file * 9 + (
        start_file - 1) * batch_per_file - 1
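    # OneCycleLR's last_epoch is the index of the last optimizer step taken; a fresh
    # run (epoch 1, file 1) evaluates to -1. Note that resuming with last_epoch > -1
    # requires an 'initial_lr' entry in each optimizer param group.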

    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr=5e-3,
                                                    total_steps=total_steps,
                                                    last_epoch=last_batch)

    # Main Loop
    for epoch, file_idx_list in task:
        if logger:
            logger.info('=========================')
            logger.info('Processing Epoch {}/{}'.format(epoch, task[-1][0]))
            logger.info('=========================')

        # Train model
        model.train()
        train_running_loss, train_n_batch = 0, 0

        for split_idx in file_idx_list:
            dl = data_loader_v2(wv,
                                y_list,
                                x_list,
                                input_split_path,
                                split_idx,
                                batch_size=batch_size,
                                shuffle=True)
            it = iter(dl)
            while True:
                try:
                    yl, xl, x_seq_len = next(it)
                    y = yl[0].to(device)
                    x = [i.to(device) for i in xl] + [x_seq_len]

                    optimizer.zero_grad()
                    yp = model(*x)  # raw logits; CrossEntropyLoss applies log-softmax internally
                    loss = loss_fn(yp, y)

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_norm=100)
                    optimizer.step()

                    train_running_loss += loss.item()
                    train_n_batch += 1

                    scheduler.step()

                except StopIteration:
                    break

                except Exception as e:
                    if logger: logger.error(e)
                    return

            del dl, it
            _ = gc.collect()

            if logger:
                logger.info(
                    'Epoch {}/{} - File {}/9 Done - Train Loss: {:.6f}'.format(
                        epoch, task[-1][0], split_idx,
                        train_running_loss / train_n_batch))

            # Save model state dict
            ck_file_name = '{}_{}_{}.pth'.format(checkpoint_prefix, epoch,
                                                 split_idx)
            ck_file_path = os.path.join(checkpoint_dir, ck_file_name)

            torch.save(model.state_dict(), ck_file_path)

        torch.cuda.empty_cache()

        # Evaluate model
        model.eval()
        test_running_loss, test_n_batch = 0, 0
        true_y, pred_y = [], []

        with torch.no_grad():
            dl = data_loader_v2(wv,
                                y_list,
                                x_list,
                                input_split_path,
                                10,
                                batch_size=batch_size,
                                shuffle=True)
            it = iter(dl)
            while True:
                try:
                    yl, xl, x_seq_len = next(it)
                    y = yl[0].to(device)
                    x = [i.to(device) for i in xl] + [x_seq_len]
                    yp = model(*x)
                    loss = loss_fn(yp, y)

                    pred_y.extend(list(F.softmax(yp, dim=1).cpu().detach().numpy()))
                    true_y.extend(list(y.cpu().detach().numpy()))

                    test_running_loss += loss.item()
                    test_n_batch += 1

                except StopIteration:
                    break

                except Exception as e:
                    if logger: logger.error(e)
                    return

            del dl, it
            _ = gc.collect()

        pred = np.argmax(np.array(pred_y), 1)
        true = np.array(true_y).reshape((-1, ))
        acc_score = accuracy_score(true, pred)

        del pred, true, pred_y, true_y
        _ = gc.collect()

        if logger:
            logger.info(
                'Epoch {}/{} Done - Test Loss: {:.6f}, Test Accuracy: {:.6f}'.
                format(epoch, task[-1][0], test_running_loss / test_n_batch,
                       acc_score))
def train(model, task, y_list, x_list, checkpoint_dir, checkpoint_prefix, device, batch_size=512, max_seq_len=100, lr=1e-3, resume_surfix=None, logger=None):
	"""
	: model - torch.nn.Module: model to be trained
	: task - list[tuple(int,list[int])]: (epoch, list of file indices) pairs to train on
	: y_list - list[str]: list of y variables
	: x_list - list[str]: list of x variables to generate embed sequences for
	: checkpoint_dir - str: path to checkpoint directory
	: checkpoint_prefix - str: prefix of checkpoint file
	: device - torch.device: device to train the model on
	: batch_size - int: size of mini batch
	: max_seq_len - int: max length for sequence input, default 100
	: lr - float: learning rate for Adam, default 1e-3
	: resume_surfix - str: '{epoch}_{file}' checkpoint suffix to reload if not training from scratch
	"""
	global input_split_path, embed_path
	if not gc.isenabled(): gc.enable()

	# Check checkpoint directory
	if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)

	# Calculate number of batches
	div, mod = divmod(90000, batch_size)
	batch_per_file = div + min(1, mod)
	batch_per_epoch = 9 * batch_per_file

	# Set up loss function, optimizer and scheduler
	loss_fn = nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
	scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=0, threshold=1e-5, threshold_mode='abs')

	if resume_surfix is not None:
		model_artifact_path = os.path.join(checkpoint_dir, '{}_{}.pth'.format(checkpoint_prefix, resume_surfix))
		model.load_state_dict(torch.load(model_artifact_path))
		if logger: logger.info('Model loaded from {}'.format(model_artifact_path))
		optimizer_artifact_path = os.path.join(checkpoint_dir, '{}_{}_opti.pth'.format(checkpoint_prefix, resume_surfix))
		optimizer.load_state_dict(torch.load(optimizer_artifact_path))
		if logger: logger.info('Optimizer loaded from {}'.format(optimizer_artifact_path))

	model.to(device)
	
	# Initialize word vector host
	wv = wv_loader_v2(x_list, embed_path, max_seq_len=max_seq_len)
	if logger: logger.info('Word vector host ready')
	
	# Main Loop
	for epoch, file_idx_list in task:
		if logger:
			logger.info('=========================')
			logger.info('Processing Epoch {}/{}'.format(epoch, task[-1][0]))
			logger.info('=========================')

		# Train model
		model.train()
		train_running_loss, train_n_batch = 0, 0

		for index, split_idx in enumerate(file_idx_list, start=1):
			dl = data_loader_v2(wv, y_list, x_list, input_split_path, split_idx, batch_size=batch_size, shuffle=True)
			it = iter(dl)
			while True:
				try:
					yl, xl, x_seq_len = next(it)
					y = torch.add(yl[0], yl[1], alpha=10).to(device)
					x = [i.to(device) for i in xl] + [x_seq_len-1]

					optimizer.zero_grad()
					yp = model(*x)  # raw logits over the combined age+gender classes; no softmax before the loss
					loss = loss_fn(yp, y)

					loss.backward()
					torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
					optimizer.step()

					train_running_loss += loss.item()
					train_n_batch += 1

				except StopIteration:
					break

				except Exception as e:
					if logger: logger.error(e)
					return 

			del dl, it
			_ = gc.collect()

			if logger:
				logger.info('Epoch {}/{} - File {}/8 Done - Train Loss: {:.6f}, Learning Rate {:.7f}'.format(epoch, task[-1][0], index, train_running_loss/train_n_batch, optimizer.param_groups[0]['lr']))

			# Save model & optimizer state dict
			ck_file_name = '{}_{}_{}.pth'.format(checkpoint_prefix, epoch, split_idx)
			ck_file_path = os.path.join(checkpoint_dir, ck_file_name)
			torch.save(model.state_dict(), ck_file_path)
			op_file_name = '{}_{}_{}_opti.pth'.format(checkpoint_prefix, epoch, split_idx)
			op_file_path = os.path.join(checkpoint_dir, op_file_name)
			torch.save(optimizer.state_dict(), op_file_path)

		torch.cuda.empty_cache()

		# Evaluate model
		model.eval()
		test_running_loss, test_n_batch = 0, 0
		true_y, pred_y = [], []

		with torch.no_grad():
			for split_idx in [9, 10]:
				dl = data_loader_v2(wv, y_list, x_list, input_split_path, split_idx, batch_size=batch_size, shuffle=True)
				it = iter(dl)
				while True:
					try:
						yl, xl, x_seq_len = next(it)
						y = torch.add(yl[0], yl[1], alpha=10).to(device)
						x = [i.to(device) for i in xl] + [x_seq_len-1]
						yp = model(*x)
						loss = loss_fn(yp, y)

						pred_y.extend(list(F.softmax(yp, dim=1).cpu().detach().numpy()))
						true_y.extend(list(y.cpu().detach().numpy()))

						test_running_loss += loss.item()
						test_n_batch += 1

					except StopIteration:
						break

					except Exception as e:
						if logger: logger.error(e)
						return 

				del dl, it
				_ = gc.collect()

		pred = np.argmax(np.array(pred_y), 1)
		true = np.array(true_y).reshape((-1,))
		age_acc = accuracy_score(true%10, pred%10)
		gen_acc = accuracy_score(true//10, pred//10)

		del pred, true, pred_y, true_y
		_ = gc.collect()

		if logger:
			logger.info('Epoch {}/{} Done - Test Loss: {:.6f}, Age Accuracy: {:.6f}, Gender Accuracy: {:.6f}, Combined Score: {:.6f}'.format(
				epoch, task[-1][0], test_running_loss/test_n_batch, age_acc, gen_acc, age_acc+gen_acc))

		scheduler.step(test_running_loss/test_n_batch)
		if logger:
			logger.info('Epoch {}/{} - Updated Learning Rate: {:.8f}'.format(epoch, task[-1][0], optimizer.param_groups[0]['lr']))
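
# The last example packs both targets into a single label with y = age + 10 * gender
# (torch.add(yl[0], yl[1], alpha=10)) and recovers the parts with y % 10 and y // 10.
# A small self-contained round-trip check of that encoding (illustrative values only,
# assuming 10 age classes and 2 gender classes, as the modulus implies):
def _check_combined_label_roundtrip():
	age = torch.tensor([0, 3, 9, 5])     # age class ids in 0..9
	gender = torch.tensor([0, 1, 1, 0])  # gender class ids in 0..1
	combined = torch.add(age, gender, alpha=10)  # age + 10 * gender: one id per (age, gender) pair
	assert torch.equal(combined % 10, age)
	assert torch.equal(combined // 10, gender)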