Image DHG

To convert the DHG/SHREC dataset to TSSI image representations:

import numpy
import matplotlib

matplotlib.rcParams['figure.figsize'] = [12.0, 8.0]
import torch
import tqdm
from data import load_data

# Device used by this script.
device_name = 'cpu'
device = torch.device(device_name)

# Paths and settings for the conversion.
config = {
    'device': device_name,
    'batch_size': 256,
    'data_pickle_string': '/home/caor/datasets/shrec_data.pckl',
    'image_folder_string': '/home/caor/datasets/dataset_imageseq_dhg1428',  # without / at the end
    'n_classes': 14
}

# NOTE(review): presumably every sequence is a (100, 66) array whose columns are
# interleaved as x0, y0, z0, x1, y1, z1, ... -- confirm against `load_data`.
x_train, x_test, y_train_14, y_train_28, y_test_14, y_test_28 = load_data(filepath=config['data_pickle_string'])

# for simpler visualizations we can group all the Xs, all the Ys and all the Zs together
# column layout afterwards: [x of every joint | y of every joint | z of every joint]
x_train = numpy.array([numpy.hstack([xi[:, 0::3], xi[:, 1::3], xi[:, 2::3]]) for xi in x_train])
x_test = numpy.array([numpy.hstack([xi[:, 0::3], xi[:, 1::3], xi[:, 2::3]]) for xi in x_test])

x_train = torch.from_numpy(x_train).float()
x_test = torch.from_numpy(x_test).float()

# Pytorch expects labels between 0 and N-1
y_train_14 = torch.Tensor(y_train_14) - 1
y_train_28 = torch.Tensor(y_train_28) - 1
y_test_14 = torch.Tensor(y_test_14) - 1
y_test_28 = torch.Tensor(y_test_28) - 1

# Full dataset (train and test concatenated along the first axis)
x_dataset = torch.cat([x_train, x_test], dim=0)
y_14_dataset = torch.cat([y_train_14, y_test_14], dim=0)
y_28_dataset = torch.cat([y_train_28, y_test_28], dim=0)

# Normalize the values: min/max are reduced over the first two axes, which leaves
# one min and one max per column, then every column is rescaled to [0, 1].
# NOTE(review): a column that is constant over the whole dataset would give
# max_ == min_ and a division by zero -- confirm this cannot happen.
min_ = x_dataset.min(0)[0].min(0)[0]
max_ = x_dataset.max(0)[0].max(0)[0]
x_dataset_std = (x_dataset - min_) / (max_ - min_)
# The columns are grouped as [all Xs | all Ys | all Zs] (see the hstack above), so the
# last axis has to be split into (component, joint) first and the two axes swapped
# afterwards. Reshaping straight to (22, 3) would put the X values of three different
# joints into one "joint" instead of the x, y, z of a single joint.
x_dataset_std_3 = x_dataset_std.reshape(-1, 100, 3, 22).permute(0, 1, 3, 2)  # shape: batch, duration, joints, joint_component


def sequence_to_image(seq, size=(256, 256), return_tensor_shaped_as_CHW=True):
    """Turn one skeleton sequence into an image-like tensor.

    The joints are reordered along ``spatial_order``, a 43-entry walk over the
    hand skeleton that follows the bones out to each fingertip and back (so
    joints repeat). The resulting (time, joint) grid is then resized to
    ``size`` with nearest-neighbour interpolation; the three coordinates
    become the three channels.

    Args:
        seq: tensor of shape (time, joints, 3), where 3 stands for xyz. It
            must contain every joint index used by ``spatial_order``
            (0 to 21).
        size: output resolution as (temporal, spatial).
        return_tensor_shaped_as_CHW: when truthy the result is shaped
            (3, size[0], size[1]); otherwise (size[0], size[1], 3).

    Returns:
        The resized tensor, laid out as selected by
        ``return_tensor_shaped_as_CHW``.

    Raises:
        AssertionError: if ``seq`` is not 3-dimensional or its last
            dimension is not 3.
    """
    # Bones of the hand skeleton, for reference. Every pair of consecutive
    # entries of `spatial_order` is one of these bones.
    #     [0, 1],  # wrist palm
    #     [0, 2],  # wrist thumb
    #     [1, 6],  # palm index
    #     [1, 10],  # palm middle
    #     [1, 14],  # palm ring
    #     [1, 18],  # palm pinky
    #     [2, 3], [3, 4], [4, 5],  # thumb
    #     [6, 7], [7, 8], [8, 9],  # index
    #     [10, 11], [11, 12], [12, 13],  # middle
    #     [14, 15], [15, 16], [16, 17],  # ring
    #     [18, 19], [19, 20], [20, 21]  # pinky
    spatial_order = [
        0, 2, 3, 4, 5, 4, 3, 2, 0,  # wrist -> thumb tip -> wrist
        1, 6, 7, 8, 9, 8, 7, 6, 1,  # palm -> index tip -> palm
        10, 11, 12, 13, 12, 11, 10, 1,  # middle tip and back to the palm
        14, 15, 16, 17, 16, 15, 14, 1,  # ring tip and back to the palm
        18, 19, 20, 21, 20, 19, 18, 1,  # pinky tip and back to the palm
        0,  # close the walk at the wrist
    ]

    # input shape: (time, joints, 3), where 3 stands for xyz
    assert len(seq.shape) == 3, 'seq must be shaped (time, joints, 3)'
    assert seq.shape[-1] == 3, 'the last dimension of seq must hold x, y, z'

    # (T, S, 3) -> (1, 3, T, S) where S = 43: interpolate() resizes the trailing
    # two dimensions and expects (batch, channels, ...) in front of them.
    batched = seq[:, spatial_order, :].permute(2, 0, 1).unsqueeze(0)

    # size = (temporal, spatial)
    batched = torch.nn.functional.interpolate(batched, size=size, mode='nearest')

    # shape: (3, size[0], size[1])
    out_image = batched.squeeze(0)

    # Plain truthiness instead of `is True`, so flags such as 1 or numpy.bool_
    # select the CHW layout as well.
    if not return_tensor_shaped_as_CHW:
        # shape: (size[0], size[1], 3)
        out_image = out_image.permute(1, 2, 0)

    return out_image


def my_save_image(tensor, fp):
    """Write a tensor to ``fp`` as a maximum-quality JPEG.

    Args:
        tensor: 3-dimensional tensor whose first dimension is moved last
            before saving (i.e. channels first). Values are multiplied by
            255 and clamped to [0, 255], so they are expected to lie in
            [0, 1].
        fp: destination handed to ``PIL.Image.Image.save``.
    """
    from PIL import Image

    # mul() allocates a new tensor, so the in-place steps below never touch the input.
    scaled = tensor.mul(255)
    # Add 0.5 so that the truncating uint8 conversion rounds to the nearest integer
    scaled.add_(0.5).clamp_(0, 255)

    # channels-first -> channels-last, as expected by PIL
    pixels = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    # subsampling=0 and quality=100 keep the JPEG as close to the data as possible
    Image.fromarray(pixels).save(fp, format='JPEG', subsampling=0, quality=100)


# ---------------------------
# Generate the images!
# ---------------------------
n_sequences = len(x_dataset_std_3)
print(n_sequences)  # 2800 images

# Path template of one image: <image folder>/seq_<index>__y14_<label>__y28_<label>.jpg
# note: the Ys labels already have been subtracted by 1 (y = y - 1). Change that if you want.
folder = '{}/seq_{:04d}__y14_{:02d}__y28_{:02d}.jpg'

samples = zip(x_dataset_std_3, y_14_dataset, y_28_dataset)
# zip() has no length, so tqdm needs `total` to draw a progress bar and estimate the remaining time
for i, (x, y14, y28) in enumerate(tqdm.tqdm(samples, total=n_sequences)):
    ximage = sequence_to_image(x, size=(512, 512))
    my_save_image(ximage, folder.format(config['image_folder_string'], i, int(y14.item()), int(y28.item())))
print('Saved dataset as images')