コード例 #1
0
def run():
    """Reload the global dataset, tune fusion weights, and print the answer.

    Side effects: rebinds the module-level ``data``, prints the list of
    result names being fused and the global ``alpha_rare``, then prints
    whatever ``write_ans`` returns for the weights picked by
    ``run_fuse_vali``.
    """
    global data
    data = dataset.Data()

    # Result names to fuse: every suffix carries the same 'union_' tag.
    fn_list = ['union_' + suffix
               for suffix in ('id_att_3', 'id_last', 'c_att_5', 'c_last')]

    print('fn list:')
    print(fn_list)
    print('alpha_rare:')
    print(alpha_rare)

    # Weights chosen on the validation split are reused for the final answer.
    weights = run_fuse_vali(fn_list)
    print(write_ans(fn_list, weight_list=weights))
コード例 #2
0
def train(params):
    """Train a VAE configured by ``params``; log to TensorBoard and checkpoint.

    ``params`` is parsed with ``u.arg_parse``; the sections read here are
    ``logging``, ``dataset``, ``model`` and ``training``.

    Each epoch runs ``train_step`` over ``data.data_train`` and writes the
    current learning rate and KL coefficient as summaries. Every
    ``log_interval_train`` epochs the model loss with ``iw_k=1`` (logged as
    ``elbo_*``) is evaluated on the test and train splits. Every
    ``log_interval_test`` epochs the loss with ``training.test_iw_k``
    (logged as ``ll_*``) is evaluated and generated samples are logged.
    Every ``save_every_epochs`` epochs a checkpoint is written (only the
    latest one is kept).

    NOTE(review): the checkpoint tracks only ``model``. The optimizer state
    (including ``optimizer.iterations``, which drives the LR schedule) and the
    epoch index are not restored, so a resumed run restarts both.
    """
    # tf.config.experimental_run_functions_eagerly(True)
    p = u.arg_parse(params)
    path = os.getcwd()
    # Plain string concatenation, not os.path.join: presumably the configured
    # dirs begin with a path separator -- TODO confirm.
    p.logging.log_dir = path + p.logging.log_dir
    os.makedirs(p.logging.log_dir, exist_ok=True)
    p.dataset.data_dir = path + p.dataset.data_dir
    os.makedirs(p.dataset.data_dir, exist_ok=True)

    data = dataset.Data(p.dataset)
    # set initial bias and mean
    # (copied from the dataset into the model config before the model is built)
    p.model.encoder.init_m = data.init_m
    p.model.decoder.init_bias = data.init_bias

    print('logging to {}\n'.format(p.logging.log_dir))

    model = models.VAE(p.model)

    # Total number of training steps; both schedules below are functions of
    # the fraction step / num_steps.
    num_steps = data.train_steps_per_epoch * p.training.num_epochs

    class LR(tf.keras.optimizers.schedules.LearningRateSchedule):
        # Piecewise schedule on s = step / num_steps: linear warm-up from 0 to
        # lr over the first 10%, constant lr until 90%, then exponential decay
        # that reaches lr * 0.01 at s == 1.
        def __init__(self, lr, num_steps):
            self.num_steps = num_steps
            self.lr = lr

        @tf.function
        def __call__(self, step):
            s = step / self.num_steps
            if s < 0.1:
                return self.lr * s / 0.1
            elif s < 0.9:
                return self.lr
            else:
                return self.lr * tf.pow(0.01, (s - 0.9) / 0.1)

    lr_schedule = LR(p.training.lr, num_steps)

    # KL coefficient: linear ramp from 0 to 1 over the first
    # kl_anneal_portion of training, then held at 1.
    # NOTE(review): this value is only written as a summary below; train_step
    # never passes it to the model -- confirm whether annealing is intended.
    def kl_schedule(step):
        s = (1.0 * step) / num_steps
        if s < p.training.kl_anneal_portion:
            return s / p.training.kl_anneal_portion
        else:
            return 1.

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    # Running mean of the per-batch evaluation loss; read and reset after
    # each evaluation pass.
    metric_mean = tf.keras.metrics.Mean('mean')
    summary_writer = tf.summary.create_file_writer(p.logging.log_dir)

    @tf.function
    def train_step(images, labels):
        # labels are unused. Inputs are resampled by u.sample_bernoulli
        # (presumably stochastic binarisation -- TODO confirm).
        images = u.sample_bernoulli(images)
        with tf.GradientTape() as tape:
            loss = model(images, training=True, iw_k=p.training.train_iw_k)
            # Add any regularisation losses registered on the model.
            loss_train = loss + tf.reduce_sum(model.losses)
        gradients = tape.gradient(loss_train, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    @tf.function
    def test_step(images, labels, iw_k=1):
        # Accumulates the model loss into metric_mean; iw_k is forwarded to
        # the model (presumably the number of importance samples).
        images = u.sample_bernoulli(images)
        loss_test = model(images, iw_k=iw_k)
        metric_mean(loss_test)

    @tf.function
    def analyze():
        # Decode prior samples to pixel probabilities and tile them into one
        # image. The fixed 10x10 grid of 28x28 tiles presumably assumes
        # num_samples_to_generate == 100 -- TODO confirm.
        z_sample = model.prior.sample(p.logging.num_samples_to_generate)
        _, x_logits = model.decoder(z_sample)
        return u.tile_image_tf(
            tf.reshape(tf.sigmoid(x_logits), [-1, 28, 28, 1]), 10, 10, 28, 28)

    checkpoint = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(checkpoint,
                                              directory=p.logging.log_dir,
                                              max_to_keep=1)
    # Resume model weights from the latest checkpoint in log_dir, if any.
    checkpoint.restore(ckpt_manager.latest_checkpoint)
    for epoch in range(p.training.num_epochs):
        t0 = time.time()
        for (images, labels) in data.data_train:
            train_step(images, labels)
        train_time = time.time() - t0
        print("Epoch: {}, Train time: {}".format(epoch + 1, train_time))

        with summary_writer.as_default():
            step = tf.cast(optimizer.iterations, tf.float32)
            tf.summary.scalar('lr', lr_schedule(step), epoch + 1)
            tf.summary.scalar('kl_coeff', kl_schedule(step), epoch + 1)

        # Single-sample (iw_k=1) evaluation on both splits.
        if (epoch + 1) % p.logging.log_interval_train == 0:
            for (images, labels) in data.data_test:
                test_step(images, labels)
            elbo_test = metric_mean.result()
            metric_mean.reset_states()
            for (images, labels) in data.data_train:
                test_step(images, labels)
            elbo_train = metric_mean.result()
            metric_mean.reset_states()
            with summary_writer.as_default():
                tf.summary.scalar('elbo_test', elbo_test, epoch + 1)
                tf.summary.scalar('elbo_train', elbo_train, epoch + 1)
            print('epoch={}, elbo_train={}, elbo_test={}'.format(
                epoch + 1, elbo_train, elbo_test))
        # Evaluation with test_iw_k on both splits, plus generated samples.
        if (epoch + 1) % p.logging.log_interval_test == 0:
            for (images, labels) in data.data_test:
                test_step(images, labels, iw_k=p.training.test_iw_k)
            ll_test = metric_mean.result()
            metric_mean.reset_states()
            for (images, labels) in data.data_train:
                test_step(images, labels, iw_k=p.training.test_iw_k)
            ll_train = metric_mean.result()
            metric_mean.reset_states()
            with summary_writer.as_default():
                tf.summary.scalar('ll_test', ll_test, epoch + 1)
                tf.summary.scalar('ll_train', ll_train, epoch + 1)
            samples_gen = analyze()
            with summary_writer.as_default():
                tf.summary.image('samples_gen', samples_gen, epoch + 1)
            print('epoch={}, ll_train={}, ll_test={}'.format(
                epoch + 1, ll_train, ll_test))
        if (epoch + 1) % p.logging.save_every_epochs == 0:
            ckpt_manager.save()
コード例 #3
0
def main(**main_args):
    """Configure a run, train it, run the final test, and log timings.

    ``main_args`` are applied to the module-level ``args`` first.
    Command-line values and then the chosen model's ``Model.args`` are
    merged via ``args.setdefault`` (presumably filling only keys that are
    still unset). A Ctrl-C during training is re-raised only if no training
    has happened yet (``T.has_train`` is false).
    """
    started_at = time.time()

    # --- configuration ------------------------------------------------
    args.update(**main_args)
    args.setdefault(**parse_args())
    args.update(run_on_yard=True)

    seed = args.seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Resolve the model class by name and merge in its defaults.
    Model = vars(models)[args.model]
    args.setdefault(**Model.args.vars())

    if args.run_test:
        # Tiny smoke-test configuration.
        args.update(epochs=2, nb_vali_step=2, max_data_line=100)

    print(args)

    # --- data ---------------------------------------------------------
    # Re-seed right before loading so data preparation starts from a known
    # RNG state.
    random.seed(seed)
    np.random.seed(args.seed)
    data = dataset.Data()

    def epochs_per_validation():
        return args.nb_train / (args.batch_size * args.nb_vali_step)

    min_epochs = epochs_per_validation()
    if min_epochs < 1.0:
        # The validation interval is longer than one pass over the training
        # data: shrink it to the number of batches in one pass.
        args.update(nb_vali_step=int(np.ceil(args.nb_train / args.batch_size)))
        print(args)
        min_epochs = epochs_per_validation()
    args.update(min_epochs=int(np.ceil(min_epochs)))

    # --- run name: <time>-<Model>-<ds>[-<msg>][-test] -------------------
    time_str = utils.get_time_str()
    model_name = Model.__name__
    name_parts = [time_str, model_name, args.ds]
    if args.msg:
        name_parts.append(args.msg)
    if args.run_test:
        name_parts.append('test')
    run_name = '-'.join(f'{part}' for part in name_parts)

    args.update(run_name=run_name)
    T = Train.Train(Model, data)

    # --- logging ------------------------------------------------------
    log_fn = f'{utils.log_dir}/{run_name}.log'
    begin_time_str = utils.get_time_str()
    print(begin_time_str, log_fn, '----- start!, pid:', os.getpid())
    args.update(pid=os.getpid())
    args.update(log=utils.Logger(fn=log_fn, verbose=args.verbose))
    write = args.log.log
    write(f'argv: {" ".join(sys.argv)}')
    write(f'log_fn: {log_fn}')
    write(f'args: {args.prt_json()}')
    write(f'Model: {model_name}')
    write(f'begin time: {begin_time_str}')

    # --- train / test -------------------------------------------------
    try:
        T.train()
    except KeyboardInterrupt as e:
        if not T.has_train:
            raise e
    test_str = T.final_test()

    args.log.log(f'\ntest: {test_str}\n', red=True)
    args.log.log(log_fn)

    hours = (time.time() - started_at) / 3600
    end_time_str = utils.get_time_str()
    args.log.log(f'end time: {end_time_str}, dt: {hours:.2f}h')
    print(end_time_str, log_fn, f'##### over, time: {hours:.2f}h')
コード例 #4
0
ファイル: main.py プロジェクト: eksuas/DBMSystemsBenchmark
def read_files(args):
    """Parse the benchmark input files into a new ``dataset.Data``.

    Paths come from ``args.films_path``, ``args.collectors_path``,
    ``args.collect_path`` and ``args.follow_path``. If any file cannot be
    opened, a message is printed and the process exits with status 1.

    Line formats (as parsed below); blank lines are skipped:
      * films:      ' % '-separated ID, title, year, actors, genre,
                    director, rating -- ``actors`` is ', '-separated.
      * collectors: '%'-separated ID, name, email.
      * collect:    '%'-separated pair, added to ``data.collectings``.
      * follow:     '%'-separated pair, appended to ``data.followings``.

    After reading the films, the collected actor and director names are
    replaced by ``dataset.Actor`` / ``dataset.Director`` objects numbered
    from 2001 and 3001 respectively.

    Returns:
        The populated ``dataset.Data``, including the fixed benchmark
        questions in ``data.questions_list``.
    """
    def read_lines(path):
        # Read the whole file and close the handle immediately.
        with open(path, "r") as handle:
            return handle.readlines()

    try:
        films_lines = read_lines(args.films_path)
        collectors_lines = read_lines(args.collectors_path)
        collect_lines = read_lines(args.collect_path)
        follow_lines = read_lines(args.follow_path)
    except IOError:
        print("The files are not found.")
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)

    # Create and initialize local variables
    data = dataset.Data()

    # Films first: they define the actor and director name sets.
    for raw in films_lines:
        # Drop the line terminator so the last field (rating) carries no
        # trailing newline.
        raw = raw.rstrip('\r\n')
        if not raw.strip():
            continue
        line = raw.split(' % ')
        movie = dataset.Movie(
            ID=line[0],
            title=line[1],
            year=line[2],
            genre=line[4],
            director=line[5],
            rating=line[6],
        )
        # identify the actors of the movie
        movie.actors = Set(line[3].split(', '))
        # add the movie, actors and directors to dataset
        data.movies.add(movie)
        data.actors.update(movie.actors)
        data.directors.add(movie.director)

    # Assign ids to actors and directors.
    # NOTE(review): set iteration order is unspecified, so which name gets
    # which ID is not guaranteed to be stable across runs/interpreters.
    data.actors = Set([dataset.Actor(ID + 2001, name) for ID, name in enumerate(data.actors)])
    data.directors = Set([dataset.Director(ID + 3001, name) for ID, name in enumerate(data.directors)])

    for raw in collectors_lines:
        # Drop the line terminator so the email field has no trailing newline.
        raw = raw.rstrip('\r\n')
        if not raw.strip():
            continue
        line = raw.split('%')
        collector = dataset.Collector(
            ID=line[0],
            name=line[1],
            email=line[2]
        )
        data.collectors.add(collector)

    for raw in collect_lines:
        raw = raw.strip()
        if not raw:
            continue
        line = raw.split('%')
        data.collectings.add((line[0], line[1]))

    for raw in follow_lines:
        raw = raw.strip()
        if not raw:
            continue
        line = raw.split('%')
        data.followings.append((line[0], line[1]))

    data.questions_list.append("1. List all actors ( userid and fullname ) who are also directors")
    data.questions_list.append("2. List all actors ( userid and fullname ) who acted in 5 or more movies.")
    data.questions_list.append("3. How many actors have acted in same movies with ’Edward Norton’?")
    data.questions_list.append("4. Which collectors collected all movies in which ’Edward Norton’ acts?")
    data.questions_list.append("5. List 10 collectors ( userid and fullname ) who collect ’The Shawshank Redemption’.")
    data.questions_list.append("6. List all userids and fullnames of xi’s which satisfy Degree(1001, xi) ≤ 3")

    return data
コード例 #5
0
import dataset
import neural_network

import time

import numpy as np
import scipy as sc
import matplotlib.pyplot as plt

from sklearn.datasets import make_circles
from IPython.display import clear_output

# Number of generated samples and number of input features per sample.
nData = 500
nFeaturesPattern = 2

# Build the training set with dataset.Data.buildDataInCircles (presumably
# two concentric circles, cf. the make_circles import -- TODO confirm).
dt = dataset.Data(nData)
dt.buildDataInCircles()
# dt.showDataGraph()
X, Y = dt.getData()

# Presumably layer widths: nFeaturesPattern inputs, hidden layers of 4 and 8
# units, one output -- TODO confirm against neural_network.NN.
topology = [nFeaturesPattern, 4, 8, 1]
nn = neural_network.NN(topology)

# Number of training iterations for the loop below.
iteration = 25000
# History of accepted cost values; seeded with 1 so the loop's first
# `error < loss[-1]` comparison has a previous value.
loss = [1]

for i in range(iteration):
    pY = nn.trainingNeuralNetwork(X, Y, learning_factor=0.5)
    error = nn.functionCost[0](nn.out[-1][1], Y)

    if error < loss[-1]:
コード例 #6
0
def main(**main_args):
    """Configure, train and evaluate one run; optionally dump features.

    The module-level ``args`` object is cleared and rebuilt on every call.
    Later ``update`` calls win, in this order:
      1. command-line args;
      2. ``main_args``;
      3. the model defaults in ``Model.args``;
      4. ``main_args`` again, so explicit keyword arguments override model
         defaults.
    ``args.run_test`` (also set when ``args.ds == 'test'``) then forces a tiny
    configuration.

    If ``args.restore_model`` is set, weights are restored from it and
    training only runs when ``args.restore_train`` is also set. Ctrl-C during
    training stops it early; the run then continues with ``T.model.restore(0)``
    and evaluation.

    Returns:
        ``(test_str, time_str)`` for a normal run. ``test_str`` is the result
        of ``T.final_test()``, or ``'none'`` when ``args.skip_vali`` is set.
        Returns ``None`` when ``args.dump_all`` or ``args.dump`` is set.
    """
    begin_time = time.time()

    # init args
    args.clear()
    command_line_args = parse_args()
    args.update(**command_line_args)
    args.update(**main_args)

    if args.ds == 'test':
        args.update(run_test=True)

    seed = args.seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # get Model, set model default args; re-apply main_args so explicit
    # keyword arguments take precedence over the model defaults.
    Model = vars(models)[args.model]
    args.update(**Model.args)
    args.update(**main_args)

    if args.run_test:
        args.update(epochs=2, nb_vali_step=2, batch_size=4)

    # get data
    data = dataset.Data()
    min_epochs = args.nb_train / (args.batch_size * args.nb_vali_step)

    if min_epochs < 1.0:
        # Validation interval exceeds one pass over the training set:
        # shrink it to the number of batches in one pass.
        args.update(nb_vali_step=int(np.ceil(args.nb_train / args.batch_size)))
        min_epochs = args.nb_train / (args.batch_size * args.nb_vali_step)
    args.update(min_epochs=int(np.ceil(min_epochs)))

    # run_name: time-x-Model-ds
    model_name = Model.__name__
    time_str = utils.get_time_str()
    run_name = f'{time_str}-{model_name}-{args.ds}'

    if args.msg:
        run_name = f'{run_name}-{args.msg}'
    if args.run_test:
        run_name = f'{run_name}-test'
    if args.restore_model:
        run_name = f'{run_name}-restored'

    args.update(run_name=run_name)

    log_fn = f'{utils.log_dir}/{run_name}.log'
    begin_time_str = utils.get_time_str()
    args.update(pid=os.getpid())
    log = utils.Logger(fn=log_fn, verbose=args.verbose)
    args.update(log=log)
    args.log.log(f'argv: {" ".join(sys.argv)}')
    args.log.log(f'log_fn: {log_fn}')
    args.log.log(f'main_args: {utils.Object(**main_args).json()}')
    args.log.log(f'args: {args.json()}')
    args.log.log(f'Model: {model_name}')
    args.log.log(f'begin time: {begin_time_str}')

    T = Train.Train(Model, data)
    if args.restore_model:
        T.model.restore_from_other(args.restore_model)
    if not args.restore_model or args.restore_train:
        try:
            T.train()
        except KeyboardInterrupt:
            # Ctrl-C is a deliberate way to stop training early: record it
            # instead of swallowing it silently, then carry on with the
            # remaining steps.
            args.log.log('KeyboardInterrupt: training stopped early', red=True)
        T.model.restore(0)

    if args.skip_vali:
        test_str = 'none'
    else:
        test_str = T.final_test()
        args.log.log(f'vali: {test_str}', red=True)

    # Feature-dump modes end the run here (returning None) and skip the
    # end-of-run timing log below.
    if args.dump_all:
        T.dump_features_all_item('vali')
        T.dump_features_all_item('test')
        return

    if args.dump:
        T.dump_features('vali')
        T.dump_features('test')
        return

    args.log.log(run_name, red=True)
    if args.restore_model:
        args.log.log(f'restored from {args.restore_model}', red=True)

    dt = time.time() - begin_time
    end_time_str = utils.get_time_str()
    args.log.log(f'end time: {end_time_str}, dt: {dt / 3600:.2f}h')
    print(end_time_str, log_fn, f'##### over, time: {dt / 3600:.2f}h')

    return test_str, time_str
コード例 #7
0
ファイル: model.py プロジェクト: nsauder/DeepKnowledgeTracing
# Adam update rules for every RNN parameter with respect to `cost`.
updates = lasagne.updates.adam(cost,
                               rnn.params,
)

# Compiled training step: returns the cost and applies the Adam updates.
train_fn = theano.function(inputs=[x, indices, labels],
                           outputs=cost,
                           updates=updates)

# Compiled evaluation step: returns cost and predictions, no parameter updates.
test_fn = theano.function(inputs=[x, indices, labels],
                          outputs=[cost, preds])



# #################### Data Flow  ####################

data = dataset.Data(params)
train_ds = data.dataset()
test_ds = data.dataset(is_test=True)


# Materialise the whole test stream into a list (presumably so it can be
# iterated repeatedly during training -- TODO confirm against later code).
with test_ds as test_gen:
    test_chunks = list(test_gen)
 
    
with train_ds as train_gen:
    enum_gen = enumerate(train_gen)
    while True:
        iter_num, data_batch = enum_gen.next()

        if iter_num > params.num_iterations:
            break
コード例 #8
0
ファイル: train.py プロジェクト: gagiter/CombinedDepth
def _load_checkpoint(load_dir, model, optimiser, criteria, map_location):
    """Restore model/optimiser state and, when present, step and sigma state.

    Sets ``args.step_start`` from ``step.pth`` so training resumes from the
    saved step. ``map_location`` is passed to ``torch.load``; ``None`` keeps
    tensors on the device they were saved from.
    """
    model.load_state_dict(
        torch.load(os.path.join(load_dir, 'model.pth'),
                   map_location=map_location))
    optimiser.load_state_dict(
        torch.load(os.path.join(load_dir, 'optimiser.pth'),
                   map_location=map_location))
    step_path = os.path.join(load_dir, 'step.pth')
    if os.path.exists(step_path):
        args.step_start = torch.load(step_path,
                                     map_location=map_location)['step']
    sigma_path = os.path.join(load_dir, 'sigma.pth')
    if os.path.exists(sigma_path):
        sigma = torch.load(sigma_path, map_location=map_location)
        criteria.previous_sigma = sigma['previous_sigma']
        criteria.next_sigma = sigma['next_sigma']


def _write_summaries(writer, data_in, data_out, losses, step):
    """Print the mean of ``losses`` and write TensorBoard summaries for ``step``."""
    loss = sum(losses) / len(losses)
    print('step:%d loss:%f' % (step, loss))
    util.visualize(data_in)
    util.visualize(data_out)
    writer.add_scalar('loss', loss, global_step=step)
    writer.add_image('image/image', data_in['image'][0], global_step=step)
    writer.add_image('image/color_map', data_in['color_map'][0],
                     global_step=step)
    writer.add_image('image/normal', data_out['normal_v'][0],
                     global_step=step)
    writer.add_text('camera', str(data_out['camera'][0].data.cpu().numpy()),
                    global_step=step)
    # NOTE(review): presence is tested on 'depth_v' but 'depth' is logged --
    # confirm which tensor was intended.
    if 'depth_v' in data_in:
        writer.add_image('image/depth_in', data_in['depth'][0],
                         global_step=step)
    if 'depth_v' in data_out:
        writer.add_image('image/depth_out', data_out['depth'][0],
                         global_step=step)
    if 'ground' in data_out:
        writer.add_text('ground', str(data_out['ground'][0].data.cpu().numpy()),
                        global_step=step)

    # (key prefix, tag prefix, summary kind) for the remaining model outputs.
    # Checked in order; the first matching prefix wins.
    routes = (
        ('base_', 'image/', 'image'),
        ('image_', 'image/', 'image'),
        ('residual_', 'residual/', 'image'),
        ('warp_', 'warp/', 'image'),
        ('grad_', 'grad/', 'image'),
        ('regular_', 'regular/', 'image'),
        ('record_', 'record/', 'image'),
        ('ground_', 'ground/', 'image'),
        ('loss', 'loss/', 'scalar'),
        ('eval_', 'eval/', 'scalar'),
        ('motion', 'motion/', 'text'),
    )
    for key in data_out:
        for prefix, tag, kind in routes:
            if not key.startswith(prefix):
                continue
            value = data_out[key]
            if kind == 'image':
                writer.add_image(tag + key, value[0], global_step=step)
            elif kind == 'scalar':
                writer.add_scalar(tag + key, value, global_step=step)
            else:
                writer.add_text(tag + key, str(value[0].data.cpu().numpy()),
                                global_step=step)
            break


def _save_checkpoint(model, optimiser, criteria, pcd, data_in, data_out, step):
    """Save model/optimiser/step/sigma state and a coloured point cloud."""
    save_dir = os.path.join('checkpoint', args.model_name)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
    torch.save(optimiser.state_dict(), os.path.join(save_dir, 'optimiser.pth'))
    torch.save({'step': step}, os.path.join(save_dir, 'step.pth'))
    torch.save(
        {
            'previous_sigma': criteria.previous_sigma,
            'next_sigma': criteria.next_sigma
        }, os.path.join(save_dir, 'sigma.pth'))

    # First sample of the batch: predicted points coloured by the input image.
    # transpose(1, 2, 0).reshape(-1, 3) flattens presumably channel-first
    # (3, H, W) arrays into (H*W, 3) rows.
    points = data_out['points'][0].data.cpu().numpy()
    pcd.points = o3d.utility.Vector3dVector(
        points.transpose(1, 2, 0).reshape(-1, 3))
    colors = data_in['image'][0].data.cpu().numpy()
    pcd.colors = o3d.utility.Vector3dVector(
        colors.transpose(1, 2, 0).reshape(-1, 3))
    o3d.io.write_point_cloud(
        os.path.join(save_dir, '%s-%010d.pcd' % (args.model_name, step)), pcd)
    print('saved to ' + save_dir)


def train():
    """Train the model for ``args.step_number`` steps with gradient accumulation.

    All configuration comes from the module-level ``args``. When
    ``args.resume > 0`` and ``checkpoint/<model_name>`` exists, training state
    is restored from it first. Gradients of ``batch_size // mini_batch_size``
    mini-batches (at least 1) are accumulated per optimiser step. Summaries
    are written every ``args.summary_freq`` steps; a checkpoint and point
    cloud are saved every ``args.save_freq`` steps. The TensorBoard writer is
    closed even if training raises.
    """
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device('cuda' if use_cuda else 'cpu')

    train_data = dataset.Data(args.data_root,
                              target_pixels=args.target_pixels,
                              target_width=args.target_width,
                              target_height=args.target_height,
                              use_number=args.use_number,
                              device=device)
    train_loader = DataLoader(train_data,
                              batch_size=args.mini_batch_size,
                              shuffle=True)

    model = Model(args.encoder,
                  rotation_scale=args.rotation_scale,
                  translation_scale=args.translation_scale,
                  depth_scale=args.depth_scale)

    criteria = Criteria(depth_weight=args.depth_weight,
                        regular_weight=args.regular_weight,
                        ref_weight=args.ref_weight,
                        ground_weight=args.ground_weight,
                        scale_weight=args.scale_weight,
                        average_weight=args.average_weight,
                        down_times=args.down_times,
                        warp_flag=args.warp_flag,
                        average_depth=args.average_depth,
                        regular_flag=args.regular_flag,
                        sigma_scale=args.sigma_scale)
    pcd = o3d.geometry.PointCloud()

    model = model.to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=args.lr)

    load_dir = os.path.join('checkpoint', args.model_name)
    if args.resume > 0 and os.path.exists(load_dir):
        # On CPU, remap GPU-saved tensors so the resume does not fail; on GPU
        # keep torch.load's default placement.
        _load_checkpoint(load_dir, model, optimiser, criteria,
                         map_location=None if use_cuda else device)

    # Integer number of mini-batches per optimiser step. A float ratio
    # (e.g. 10 / 4 = 2.5) makes `step % ratio == 0` true only rarely, and a
    # ratio below 1 would otherwise become 0 here.
    accum_steps = max(1, args.batch_size // args.mini_batch_size)

    date_time = datetime.now().strftime("_%Y_%m_%d_%H_%M_%S")
    writer = SummaryWriter(os.path.join('runs', args.model_name + date_time))
    try:
        writer.add_text('args', str(args), 0)
        model.train()
        losses = []
        data_iter = iter(train_loader)
        for step in range(args.step_start, args.step_start + args.step_number):
            try:
                data_in = next(data_iter)
            except StopIteration:
                # Training is step-based: restart the loader when exhausted.
                data_iter = iter(train_loader)
                data_in = next(data_iter)
            data_out = model(data_in)
            loss = criteria(data_in, data_out)
            loss.backward()
            losses.append(loss.item())
            if step % accum_steps == 0:
                optimiser.step()
                optimiser.zero_grad()

            if step % args.summary_freq == 0:
                _write_summaries(writer, data_in, data_out, losses, step)
                losses = []

            if step % args.save_freq == 0:
                _save_checkpoint(model, optimiser, criteria, pcd,
                                 data_in, data_out, step)
    finally:
        writer.close()
コード例 #9
0
ファイル: train.py プロジェクト: youngflyasd/poi2vec
    # POI2VEC
    loss = p2v_model(user, context, target)

    loss.backward()
    optimizer.step()
    gc.collect()

    return loss.data.cpu().numpy()[0]


##############################################################################################
##############################################################################################
if __name__ == "__main__":

    # Data Preparation
    data = dataset.Data()
    poi_cnt, user_cnt = data.load()

    # Model Preparation
    p2v_model = models.POI2VEC(poi_cnt, user_cnt, data.id2route, data.id2lr,
                               data.id2prob).cuda()
    loss_model = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(parameters(p2v_model),
                                lr=config.learning_rate,
                                momentum=config.momentum)

    for i in xrange(config.num_epochs):
        # Training
        batch_loss = 0.
        train_batches = data.train_batch_iter(config.batch_size)
        for j, train_batch in enumerate(train_batches):
コード例 #10
0
ファイル: test.py プロジェクト: gagiter/CombinedDepth
parser.add_argument('--encoder', type=str, default='mobilenet_v2')
parser.add_argument('--model_name', type=str, default='model')
parser.add_argument('--no_cuda', action='store_true', default=False)
parser.add_argument('--gpu_id', type=str, default='1')
parser.add_argument('--target_pixels', type=int, default=300000)
parser.add_argument('--target_width', type=int, default=640)
parser.add_argument('--target_height', type=int, default=480)

args = parser.parse_args()
# Restrict visible GPUs before any CUDA device is used.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
use_cuda = torch.cuda.is_available() and not args.no_cuda
device = torch.device('cuda' if use_cuda else 'cpu')

# Evaluation data: one sample per batch, in file order.
data = dataset.Data(args.data_root,
                    target_pixels=args.target_pixels,
                    target_width=args.target_width,
                    target_height=args.target_height,
                    device=device)
loader = DataLoader(data, batch_size=1, shuffle=False)

model = Model(args.encoder)
model = model.to(device)
load_dir = os.path.join('checkpoint', args.model_name)
# map_location puts the stored tensors on the device chosen above, so a
# checkpoint saved on a GPU can be evaluated with --no_cuda or on a CPU-only
# machine instead of failing during deserialisation.
model.load_state_dict(
    torch.load(os.path.join(load_dir, 'model.pth'), map_location=device))
model.eval()

date_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_")
writer = SummaryWriter(os.path.join('test', date_time + args.model_name))
writer.add_text('args', str(args), 0)

with torch.no_grad():
コード例 #11
0
def PreprocessData(imageTransforms, slices, rootDir, inputDir, segmDir, batchSize, shuffleFlag) :
	"""Wrap a ``DataUtils.Data`` dataset in a ``torch.utils.data.DataLoader``.

	``imageTransforms`` is passed positionally to ``DataUtils.Data``;
	``slices`` becomes its ``output_dirs``, and ``rootDir`` / ``inputDir`` /
	``segmDir`` become ``root_dir`` / ``input_dir`` / ``segm_dir``.
	``batchSize`` and ``shuffleFlag`` configure the returned loader.
	"""
	source = DataUtils.Data(
		imageTransforms,
		root_dir=rootDir,
		output_dirs=slices,
		input_dir=inputDir,
		segm_dir=segmDir,
	)
	return torch.utils.data.DataLoader(source, batch_size=batchSize, shuffle=shuffleFlag)
コード例 #12
0
# Logging / checkpoint intervals, measured in training steps.
parser.add_argument('--summary_freq', type=int, default=10, metavar='N',
                    help='how frequency to summary')
parser.add_argument('--save_freq', type=int, default=100, metavar='N',
                    help='how frequency to save')
parser.add_argument('--eval_freq', type=int, default=100, metavar='N',
                    help='how frequency to eval')
parser.add_argument('--image_size', type=int, default=512,
                    help='image size to train')


args = parser.parse_args()

use_cuda = torch.cuda.is_available() and not args.no_cuda
# '%d' formatting requires args.gpu_id to be an integer.
device = torch.device('cuda:%d' % args.gpu_id if use_cuda else 'cpu')

# Both splits are read from data/<args.dataset>.
train_data = dataset.Data(os.path.join('data', args.dataset), size=args.image_size, mode='train', device=device)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

# NOTE(review): unlike train_data, val_data is built without `device=device`
# -- confirm this is intentional.
val_data = dataset.Data(os.path.join('data', args.dataset), size=args.image_size, mode='val')
val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False)

model = Model()
model = model.to(device)
criterion = Criterion()
optimiser = optim.Adam(model.parameters(), lr=args.lr)

# TensorBoard run directory: runs/<timestamp>_<model_name>.
date_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_")
writer = SummaryWriter(os.path.join('runs', date_time + args.model_name))

load_dir = os.path.join('checkpoint', args.model_name)
if args.resume > 0 and os.path.exists(load_dir):